mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-02 15:36:41 -05:00
extract all session context before checking conditions
This commit is contained in:
parent
cc513c1002
commit
0d31a68694
@ -3,16 +3,14 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/common/strmatcher"
|
||||
)
|
||||
|
||||
type Condition interface {
|
||||
Apply(ctx context.Context) bool
|
||||
Apply(ctx *Context) bool
|
||||
}
|
||||
|
||||
type ConditionChan []Condition
|
||||
@ -27,7 +25,7 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan {
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *ConditionChan) Apply(ctx context.Context) bool {
|
||||
func (v *ConditionChan) Apply(ctx *Context) bool {
|
||||
for _, cond := range *v {
|
||||
if !cond.Apply(ctx) {
|
||||
return false
|
||||
@ -84,46 +82,36 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool {
|
||||
return m.matchers.Match(domain) > 0
|
||||
}
|
||||
|
||||
func (m *DomainMatcher) Apply(ctx context.Context) bool {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
func (m *DomainMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
|
||||
return false
|
||||
}
|
||||
dest := outbound.Target
|
||||
dest := ctx.Outbound.Target
|
||||
if !dest.Address.Family().IsDomain() {
|
||||
return false
|
||||
}
|
||||
return m.ApplyDomain(dest.Address.Domain())
|
||||
}
|
||||
|
||||
func sourceFromContext(ctx context.Context) net.Destination {
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound == nil {
|
||||
return net.Destination{}
|
||||
}
|
||||
return inbound.Source
|
||||
}
|
||||
|
||||
func targetFromContent(ctx context.Context) net.Destination {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil {
|
||||
return net.Destination{}
|
||||
}
|
||||
return outbound.Target
|
||||
}
|
||||
|
||||
func resolvedIPFromContext(ctx context.Context) []net.IP {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil {
|
||||
func getIPsFromSource(ctx *Context) []net.IP {
|
||||
if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
|
||||
return nil
|
||||
}
|
||||
return outbound.ResolvedIPs
|
||||
dest := ctx.Inbound.Source
|
||||
if dest.Address.Family().IsDomain() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []net.IP{dest.Address.IP()}
|
||||
}
|
||||
|
||||
func getIPsFromTarget(ctx *Context) []net.IP {
|
||||
return ctx.GetTargetIPs()
|
||||
}
|
||||
|
||||
type MultiGeoIPMatcher struct {
|
||||
matchers []*GeoIPMatcher
|
||||
destFunc func(context.Context) net.Destination
|
||||
resolvedIPFunc func(context.Context) []net.IP
|
||||
ipFunc func(*Context) []net.IP
|
||||
}
|
||||
|
||||
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
|
||||
@ -141,30 +129,16 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
|
||||
}
|
||||
|
||||
if onSource {
|
||||
matcher.destFunc = sourceFromContext
|
||||
matcher.ipFunc = getIPsFromSource
|
||||
} else {
|
||||
matcher.destFunc = targetFromContent
|
||||
matcher.resolvedIPFunc = resolvedIPFromContext
|
||||
matcher.ipFunc = getIPsFromTarget
|
||||
}
|
||||
|
||||
return matcher, nil
|
||||
}
|
||||
|
||||
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
||||
ips := make([]net.IP, 0, 4)
|
||||
|
||||
dest := m.destFunc(ctx)
|
||||
|
||||
if dest.IsValid() && dest.Address.Family().IsIP() {
|
||||
ips = append(ips, dest.Address.IP())
|
||||
}
|
||||
|
||||
if m.resolvedIPFunc != nil {
|
||||
rips := m.resolvedIPFunc(ctx)
|
||||
if len(rips) > 0 {
|
||||
ips = append(ips, rips...)
|
||||
}
|
||||
}
|
||||
func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool {
|
||||
ips := m.ipFunc(ctx)
|
||||
|
||||
for _, ip := range ips {
|
||||
for _, matcher := range m.matchers {
|
||||
@ -186,12 +160,11 @@ func NewPortMatcher(list *net.PortList) *PortMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *PortMatcher) Apply(ctx context.Context) bool {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
func (v *PortMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
|
||||
return false
|
||||
}
|
||||
return v.port.Contains(outbound.Target.Port)
|
||||
return v.port.Contains(ctx.Outbound.Target.Port)
|
||||
}
|
||||
|
||||
type NetworkMatcher struct {
|
||||
@ -206,12 +179,11 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher {
|
||||
return matcher
|
||||
}
|
||||
|
||||
func (v NetworkMatcher) Apply(ctx context.Context) bool {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
func (v NetworkMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
|
||||
return false
|
||||
}
|
||||
return v.list[int(outbound.Target.Network)]
|
||||
return v.list[int(ctx.Outbound.Target.Network)]
|
||||
}
|
||||
|
||||
type UserMatcher struct {
|
||||
@ -230,13 +202,12 @@ func NewUserMatcher(users []string) *UserMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *UserMatcher) Apply(ctx context.Context) bool {
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound == nil {
|
||||
func (v *UserMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Inbound == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
user := inbound.User
|
||||
user := ctx.Inbound.User
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
@ -264,12 +235,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *InboundTagMatcher) Apply(ctx context.Context) bool {
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
if inbound == nil || len(inbound.Tag) == 0 {
|
||||
func (v *InboundTagMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 {
|
||||
return false
|
||||
}
|
||||
tag := inbound.Tag
|
||||
tag := ctx.Inbound.Tag
|
||||
for _, t := range v.tags {
|
||||
if t == tag {
|
||||
return true
|
||||
@ -296,14 +266,12 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ProtocolMatcher) Apply(ctx context.Context) bool {
|
||||
content := session.ContentFromContext(ctx)
|
||||
|
||||
if content == nil {
|
||||
func (m *ProtocolMatcher) Apply(ctx *Context) bool {
|
||||
if ctx.Content == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
protocol := content.Protocol
|
||||
protocol := ctx.Content.Protocol
|
||||
for _, p := range m.protocols {
|
||||
if strings.HasPrefix(protocol, p) {
|
||||
return true
|
||||
|
@ -1,7 +1,6 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@ -28,17 +27,17 @@ func init() {
|
||||
common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat")))
|
||||
}
|
||||
|
||||
func withOutbound(outbound *session.Outbound) context.Context {
|
||||
return session.ContextWithOutbound(context.Background(), outbound)
|
||||
func withOutbound(outbound *session.Outbound) *Context {
|
||||
return &Context{Outbound: outbound}
|
||||
}
|
||||
|
||||
func withInbound(inbound *session.Inbound) context.Context {
|
||||
return session.ContextWithInbound(context.Background(), inbound)
|
||||
func withInbound(inbound *session.Inbound) *Context {
|
||||
return &Context{Inbound: inbound}
|
||||
}
|
||||
|
||||
func TestRoutingRule(t *testing.T) {
|
||||
type ruleTest struct {
|
||||
input context.Context
|
||||
input *Context
|
||||
output bool
|
||||
}
|
||||
|
||||
@ -89,7 +88,7 @@ func TestRoutingRule(t *testing.T) {
|
||||
output: false,
|
||||
},
|
||||
{
|
||||
input: context.Background(),
|
||||
input: &Context{},
|
||||
output: false,
|
||||
},
|
||||
},
|
||||
@ -125,7 +124,7 @@ func TestRoutingRule(t *testing.T) {
|
||||
output: true,
|
||||
},
|
||||
{
|
||||
input: context.Background(),
|
||||
input: &Context{},
|
||||
output: false,
|
||||
},
|
||||
},
|
||||
@ -165,7 +164,7 @@ func TestRoutingRule(t *testing.T) {
|
||||
output: true,
|
||||
},
|
||||
{
|
||||
input: context.Background(),
|
||||
input: &Context{},
|
||||
output: false,
|
||||
},
|
||||
},
|
||||
@ -206,7 +205,7 @@ func TestRoutingRule(t *testing.T) {
|
||||
output: false,
|
||||
},
|
||||
{
|
||||
input: context.Background(),
|
||||
input: &Context{},
|
||||
output: false,
|
||||
},
|
||||
},
|
||||
@ -217,7 +216,7 @@ func TestRoutingRule(t *testing.T) {
|
||||
},
|
||||
test: []ruleTest{
|
||||
{
|
||||
input: session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}),
|
||||
input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}},
|
||||
output: true,
|
||||
},
|
||||
},
|
||||
|
@ -3,9 +3,7 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
net "v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/features/outbound"
|
||||
)
|
||||
|
||||
@ -61,7 +59,7 @@ func (r *Rule) GetTag() (string, error) {
|
||||
return r.Tag, nil
|
||||
}
|
||||
|
||||
func (r *Rule) Apply(ctx context.Context) bool {
|
||||
func (r *Rule) Apply(ctx *Context) bool {
|
||||
return r.Condition.Apply(ctx)
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"v2ray.com/core"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/features/dns"
|
||||
"v2ray.com/core/features/outbound"
|
||||
@ -85,44 +86,33 @@ func isDomainOutbound(outbound *session.Outbound) bool {
|
||||
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
|
||||
}
|
||||
|
||||
func (r *Router) resolveIP(outbound *session.Outbound) error {
|
||||
domain := outbound.Target.Address.Domain()
|
||||
ips, err := r.dns.LookupIP(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
outbound.ResolvedIPs = ips
|
||||
return nil
|
||||
}
|
||||
|
||||
// PickRoute implements routing.Router.
|
||||
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
|
||||
if err := r.resolveIP(outbound); err != nil {
|
||||
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||
sessionContext := &Context{
|
||||
Inbound: session.InboundFromContext(ctx),
|
||||
Outbound: session.OutboundFromContext(ctx),
|
||||
Content: session.ContentFromContext(ctx),
|
||||
}
|
||||
|
||||
if r.domainStrategy == Config_IpOnDemand {
|
||||
sessionContext.dnsClient = r.dns
|
||||
}
|
||||
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
if rule.Apply(sessionContext) {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
|
||||
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
|
||||
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) {
|
||||
return nil, common.ErrNoClue
|
||||
}
|
||||
|
||||
if err := r.resolveIP(outbound); err != nil {
|
||||
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||
return nil, common.ErrNoClue
|
||||
}
|
||||
sessionContext.dnsClient = r.dns
|
||||
|
||||
// Try applying rules again if we have IPs.
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
if rule.Apply(sessionContext) {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
@ -144,3 +134,37 @@ func (*Router) Close() error {
|
||||
func (*Router) Type() interface{} {
|
||||
return routing.RouterType()
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
Inbound *session.Inbound
|
||||
Outbound *session.Outbound
|
||||
Content *session.Content
|
||||
|
||||
dnsClient dns.Client
|
||||
}
|
||||
|
||||
func (c *Context) GetTargetIPs() []net.IP {
|
||||
if c.Outbound == nil || !c.Outbound.Target.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.Outbound.Target.Address.Family().IsIP() {
|
||||
return []net.IP{c.Outbound.Target.Address.IP()}
|
||||
}
|
||||
|
||||
if len(c.Outbound.ResolvedIPs) > 0 {
|
||||
return c.Outbound.ResolvedIPs
|
||||
}
|
||||
|
||||
if c.dnsClient != nil {
|
||||
domain := c.Outbound.Target.Address.Domain()
|
||||
ips, err := c.dnsClient.LookupIP(domain)
|
||||
if err == nil {
|
||||
c.Outbound.ResolvedIPs = ips
|
||||
return ips
|
||||
}
|
||||
newError("resolve ip for ", domain).Base(err).WriteToLog()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@ -42,7 +43,7 @@ func TestSimpleRouter(t *testing.T) {
|
||||
HandlerSelector: mockHs,
|
||||
}))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
@ -83,7 +84,7 @@ func TestSimpleBalancer(t *testing.T) {
|
||||
HandlerSelector: mockHs,
|
||||
}))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
@ -118,7 +119,7 @@ func TestIPOnDemand(t *testing.T) {
|
||||
r := new(Router)
|
||||
common.Must(r.Init(config, mockDns, nil))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
@ -153,7 +154,7 @@ func TestIPIfNonMatchDomain(t *testing.T) {
|
||||
r := new(Router)
|
||||
common.Must(r.Init(config, mockDns, nil))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
@ -187,7 +188,7 @@ func TestIPIfNonMatchIP(t *testing.T) {
|
||||
r := new(Router)
|
||||
common.Must(r.Init(config, mockDns, nil))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
|
||||
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
|
Loading…
Reference in New Issue
Block a user