1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-02 15:36:41 -05:00

extract all session context before checking conditions

This commit is contained in:
Darien Raymond 2019-02-28 09:28:55 +01:00
parent cc513c1002
commit 0d31a68694
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
5 changed files with 103 additions and 113 deletions

View File

@ -3,16 +3,14 @@
package router
import (
"context"
"strings"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/common/strmatcher"
)
type Condition interface {
Apply(ctx context.Context) bool
Apply(ctx *Context) bool
}
type ConditionChan []Condition
@ -27,7 +25,7 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan {
return v
}
func (v *ConditionChan) Apply(ctx context.Context) bool {
func (v *ConditionChan) Apply(ctx *Context) bool {
for _, cond := range *v {
if !cond.Apply(ctx) {
return false
@ -84,46 +82,36 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool {
return m.matchers.Match(domain) > 0
}
func (m *DomainMatcher) Apply(ctx context.Context) bool {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
func (m *DomainMatcher) Apply(ctx *Context) bool {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return false
}
dest := outbound.Target
dest := ctx.Outbound.Target
if !dest.Address.Family().IsDomain() {
return false
}
return m.ApplyDomain(dest.Address.Domain())
}
func sourceFromContext(ctx context.Context) net.Destination {
inbound := session.InboundFromContext(ctx)
if inbound == nil {
return net.Destination{}
}
return inbound.Source
}
func targetFromContent(ctx context.Context) net.Destination {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return net.Destination{}
}
return outbound.Target
}
func resolvedIPFromContext(ctx context.Context) []net.IP {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
func getIPsFromSource(ctx *Context) []net.IP {
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
return nil
}
return outbound.ResolvedIPs
dest := ctx.Inbound.Source
if dest.Address.Family().IsDomain() {
return nil
}
return []net.IP{dest.Address.IP()}
}
func getIPsFromTarget(ctx *Context) []net.IP {
return ctx.GetTargetIPs()
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
resolvedIPFunc func(context.Context) []net.IP
ipFunc func(*Context) []net.IP
}
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@ -141,30 +129,16 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
}
if onSource {
matcher.destFunc = sourceFromContext
matcher.ipFunc = getIPsFromSource
} else {
matcher.destFunc = targetFromContent
matcher.resolvedIPFunc = resolvedIPFromContext
matcher.ipFunc = getIPsFromTarget
}
return matcher, nil
}
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4)
dest := m.destFunc(ctx)
if dest.IsValid() && dest.Address.Family().IsIP() {
ips = append(ips, dest.Address.IP())
}
if m.resolvedIPFunc != nil {
rips := m.resolvedIPFunc(ctx)
if len(rips) > 0 {
ips = append(ips, rips...)
}
}
func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool {
ips := m.ipFunc(ctx)
for _, ip := range ips {
for _, matcher := range m.matchers {
@ -186,12 +160,11 @@ func NewPortMatcher(list *net.PortList) *PortMatcher {
}
}
func (v *PortMatcher) Apply(ctx context.Context) bool {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
func (v *PortMatcher) Apply(ctx *Context) bool {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return false
}
return v.port.Contains(outbound.Target.Port)
return v.port.Contains(ctx.Outbound.Target.Port)
}
type NetworkMatcher struct {
@ -206,12 +179,11 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher {
return matcher
}
func (v NetworkMatcher) Apply(ctx context.Context) bool {
outbound := session.OutboundFromContext(ctx)
if outbound == nil || !outbound.Target.IsValid() {
func (v NetworkMatcher) Apply(ctx *Context) bool {
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
return false
}
return v.list[int(outbound.Target.Network)]
return v.list[int(ctx.Outbound.Target.Network)]
}
type UserMatcher struct {
@ -230,13 +202,12 @@ func NewUserMatcher(users []string) *UserMatcher {
}
}
func (v *UserMatcher) Apply(ctx context.Context) bool {
inbound := session.InboundFromContext(ctx)
if inbound == nil {
func (v *UserMatcher) Apply(ctx *Context) bool {
if ctx.Inbound == nil {
return false
}
user := inbound.User
user := ctx.Inbound.User
if user == nil {
return false
}
@ -264,12 +235,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher {
}
}
func (v *InboundTagMatcher) Apply(ctx context.Context) bool {
inbound := session.InboundFromContext(ctx)
if inbound == nil || len(inbound.Tag) == 0 {
func (v *InboundTagMatcher) Apply(ctx *Context) bool {
if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 {
return false
}
tag := inbound.Tag
tag := ctx.Inbound.Tag
for _, t := range v.tags {
if t == tag {
return true
@ -296,14 +266,12 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
}
}
func (m *ProtocolMatcher) Apply(ctx context.Context) bool {
content := session.ContentFromContext(ctx)
if content == nil {
func (m *ProtocolMatcher) Apply(ctx *Context) bool {
if ctx.Content == nil {
return false
}
protocol := content.Protocol
protocol := ctx.Content.Protocol
for _, p := range m.protocols {
if strings.HasPrefix(protocol, p) {
return true

View File

@ -1,7 +1,6 @@
package router_test
import (
"context"
"os"
"path/filepath"
"strconv"
@ -28,17 +27,17 @@ func init() {
common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat")))
}
func withOutbound(outbound *session.Outbound) context.Context {
return session.ContextWithOutbound(context.Background(), outbound)
func withOutbound(outbound *session.Outbound) *Context {
return &Context{Outbound: outbound}
}
func withInbound(inbound *session.Inbound) context.Context {
return session.ContextWithInbound(context.Background(), inbound)
func withInbound(inbound *session.Inbound) *Context {
return &Context{Inbound: inbound}
}
func TestRoutingRule(t *testing.T) {
type ruleTest struct {
input context.Context
input *Context
output bool
}
@ -89,7 +88,7 @@ func TestRoutingRule(t *testing.T) {
output: false,
},
{
input: context.Background(),
input: &Context{},
output: false,
},
},
@ -125,7 +124,7 @@ func TestRoutingRule(t *testing.T) {
output: true,
},
{
input: context.Background(),
input: &Context{},
output: false,
},
},
@ -165,7 +164,7 @@ func TestRoutingRule(t *testing.T) {
output: true,
},
{
input: context.Background(),
input: &Context{},
output: false,
},
},
@ -206,7 +205,7 @@ func TestRoutingRule(t *testing.T) {
output: false,
},
{
input: context.Background(),
input: &Context{},
output: false,
},
},
@ -217,7 +216,7 @@ func TestRoutingRule(t *testing.T) {
},
test: []ruleTest{
{
input: session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}),
input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}},
output: true,
},
},

View File

@ -3,9 +3,7 @@
package router
import (
"context"
net "v2ray.com/core/common/net"
"v2ray.com/core/common/net"
"v2ray.com/core/features/outbound"
)
@ -61,7 +59,7 @@ func (r *Rule) GetTag() (string, error) {
return r.Tag, nil
}
func (r *Rule) Apply(ctx context.Context) bool {
func (r *Rule) Apply(ctx *Context) bool {
return r.Condition.Apply(ctx)
}

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound"
@ -85,44 +86,33 @@ func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
}
func (r *Router) resolveIP(outbound *session.Outbound) error {
domain := outbound.Target.Address.Domain()
ips, err := r.dns.LookupIP(domain)
if err != nil {
return err
}
outbound.ResolvedIPs = ips
return nil
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
outbound := session.OutboundFromContext(ctx)
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
sessionContext := &Context{
Inbound: session.InboundFromContext(ctx),
Outbound: session.OutboundFromContext(ctx),
Content: session.ContentFromContext(ctx),
}
if r.domainStrategy == Config_IpOnDemand {
sessionContext.dnsClient = r.dns
}
for _, rule := range r.rules {
if rule.Apply(ctx) {
if rule.Apply(sessionContext) {
return rule, nil
}
}
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) {
return nil, common.ErrNoClue
}
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
return nil, common.ErrNoClue
}
sessionContext.dnsClient = r.dns
// Try applying rules again if we have IPs.
for _, rule := range r.rules {
if rule.Apply(ctx) {
if rule.Apply(sessionContext) {
return rule, nil
}
}
@ -144,3 +134,37 @@ func (*Router) Close() error {
func (*Router) Type() interface{} {
return routing.RouterType()
}
type Context struct {
Inbound *session.Inbound
Outbound *session.Outbound
Content *session.Content
dnsClient dns.Client
}
func (c *Context) GetTargetIPs() []net.IP {
if c.Outbound == nil || !c.Outbound.Target.IsValid() {
return nil
}
if c.Outbound.Target.Address.Family().IsIP() {
return []net.IP{c.Outbound.Target.Address.IP()}
}
if len(c.Outbound.ResolvedIPs) > 0 {
return c.Outbound.ResolvedIPs
}
if c.dnsClient != nil {
domain := c.Outbound.Target.Address.Domain()
ips, err := c.dnsClient.LookupIP(domain)
if err == nil {
c.Outbound.ResolvedIPs = ips
return ips
}
newError("resolve ip for ", domain).Base(err).WriteToLog()
}
return nil
}

View File

@ -1,6 +1,7 @@
package router_test
import (
"context"
"testing"
"github.com/golang/mock/gomock"
@ -42,7 +43,7 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs,
}))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
@ -83,7 +84,7 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs,
}))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
@ -118,7 +119,7 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
@ -153,7 +154,7 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
@ -187,7 +188,7 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {