1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-17 23:06:30 -05:00

use session.Outbound.ResolvedIPs

This commit is contained in:
Darien Raymond 2018-12-04 20:36:51 +01:00
parent 98d89aebc2
commit 82d562d1f0
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 124 additions and 83 deletions

View File

@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
return outbound.Target
}
func resolvedIPFromContext(ctx context.Context) []net.IP {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return nil
}
return outbound.ResolvedIPs
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
resolvedIPFunc func(context.Context) []net.IP
}
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
matchers = append(matchers, matcher)
}
var destFunc func(context.Context) net.Destination
if onSource {
destFunc = sourceFromContext
} else {
destFunc = targetFromContent
matcher := &MultiGeoIPMatcher{
matchers: matchers,
}
return &MultiGeoIPMatcher{
matchers: matchers,
destFunc: destFunc,
}, nil
if onSource {
matcher.destFunc = sourceFromContext
} else {
matcher.destFunc = targetFromContent
matcher.resolvedIPFunc = resolvedIPFromContext
}
return matcher, nil
}
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
if dest.IsValid() && dest.Address.Family().IsIP() {
ips = append(ips, dest.Address.IP())
} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
ips = append(ips, rip.IP())
}
if m.resolvedIPFunc != nil {
rips := m.resolvedIPFunc(ctx)
if len(rips) > 0 {
ips = append(ips, rips...)
}
}

View File

@ -7,32 +7,12 @@ import (
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/routing"
)
type key uint32
const (
resolvedIPsKey key = iota
)
type IPResolver interface {
Resolve() []net.Address
}
func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
return context.WithValue(ctx, resolvedIPsKey, f)
}
func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
return ips, ok
}
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
r := new(Router)
@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
return nil
}
type ipResolver struct {
dns dns.Client
ip []net.Address
domain string
resolved bool
}
func (r *ipResolver) Resolve() []net.Address {
if r.resolved {
return r.ip
}
newError("looking for IP for domain: ", r.domain).WriteToLog()
r.resolved = true
ips, err := r.dns.LookupIP(r.domain)
if err != nil {
newError("failed to get IP address").Base(err).WriteToLog()
}
if len(ips) == 0 {
return nil
}
r.ip = make([]net.Address, len(ips))
for i, ip := range ips {
r.ip[i] = net.IPAddress(ip)
}
return r.ip
}
func (r *Router) PickRoute(ctx context.Context) (string, error) {
rule, err := r.pickRouteInternal(ctx)
if err != nil {
@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
return rule.GetTag()
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
resolver := &ipResolver{
dns: r.dns,
func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
}
func (r *Router) resolveIP(outbound *session.Outbound) error {
domain := outbound.Target.Address.Domain()
ips, err := r.dns.LookupIP(domain)
if err != nil {
return err
}
outbound.ResolvedIPs = ips
return nil
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
outbound := session.OutboundFromContext(ctx)
if r.domainStrategy == Config_IpOnDemand {
if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
resolver.domain = outbound.Target.Address.Domain()
ctx = ContextWithResolveIPs(ctx, resolver)
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
}
}
if outbound == nil || !outbound.Target.IsValid() {
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
return nil, common.ErrNoClue
}
dest := outbound.Target
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
resolver.domain = dest.Address.Domain()
ips := resolver.Resolve()
if len(ips) > 0 {
ctx = ContextWithResolveIPs(ctx, resolver)
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
return nil, common.ErrNoClue
}
// Try applying rules again if we have IPs.
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}

View File

@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchDomain(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{192, 168, 0, 0},
Prefix: 16,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchIP(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{127, 0, 0, 0},
Prefix: 8,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}