mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-17 23:06:30 -05:00
use session.Outbound.ResolvedIPs
This commit is contained in:
parent
98d89aebc2
commit
82d562d1f0
@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
|
||||
return outbound.Target
|
||||
}
|
||||
|
||||
func resolvedIPFromContext(ctx context.Context) []net.IP {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if outbound == nil {
|
||||
return nil
|
||||
}
|
||||
return outbound.ResolvedIPs
|
||||
}
|
||||
|
||||
type MultiGeoIPMatcher struct {
|
||||
matchers []*GeoIPMatcher
|
||||
destFunc func(context.Context) net.Destination
|
||||
matchers []*GeoIPMatcher
|
||||
destFunc func(context.Context) net.Destination
|
||||
resolvedIPFunc func(context.Context) []net.IP
|
||||
}
|
||||
|
||||
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
|
||||
@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
|
||||
matchers = append(matchers, matcher)
|
||||
}
|
||||
|
||||
var destFunc func(context.Context) net.Destination
|
||||
if onSource {
|
||||
destFunc = sourceFromContext
|
||||
} else {
|
||||
destFunc = targetFromContent
|
||||
matcher := &MultiGeoIPMatcher{
|
||||
matchers: matchers,
|
||||
}
|
||||
|
||||
return &MultiGeoIPMatcher{
|
||||
matchers: matchers,
|
||||
destFunc: destFunc,
|
||||
}, nil
|
||||
if onSource {
|
||||
matcher.destFunc = sourceFromContext
|
||||
} else {
|
||||
matcher.destFunc = targetFromContent
|
||||
matcher.resolvedIPFunc = resolvedIPFromContext
|
||||
}
|
||||
|
||||
return matcher, nil
|
||||
}
|
||||
|
||||
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
||||
@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
|
||||
|
||||
if dest.IsValid() && dest.Address.Family().IsIP() {
|
||||
ips = append(ips, dest.Address.IP())
|
||||
} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
|
||||
resolvedIPs := resolver.Resolve()
|
||||
for _, rip := range resolvedIPs {
|
||||
ips = append(ips, rip.IP())
|
||||
}
|
||||
|
||||
if m.resolvedIPFunc != nil {
|
||||
rips := m.resolvedIPFunc(ctx)
|
||||
if len(rips) > 0 {
|
||||
ips = append(ips, rips...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,32 +7,12 @@ import (
|
||||
|
||||
"v2ray.com/core"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/features/dns"
|
||||
"v2ray.com/core/features/outbound"
|
||||
"v2ray.com/core/features/routing"
|
||||
)
|
||||
|
||||
type key uint32
|
||||
|
||||
const (
|
||||
resolvedIPsKey key = iota
|
||||
)
|
||||
|
||||
type IPResolver interface {
|
||||
Resolve() []net.Address
|
||||
}
|
||||
|
||||
func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
|
||||
return context.WithValue(ctx, resolvedIPsKey, f)
|
||||
}
|
||||
|
||||
func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
|
||||
ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
|
||||
return ips, ok
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
r := new(Router)
|
||||
@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipResolver struct {
|
||||
dns dns.Client
|
||||
ip []net.Address
|
||||
domain string
|
||||
resolved bool
|
||||
}
|
||||
|
||||
func (r *ipResolver) Resolve() []net.Address {
|
||||
if r.resolved {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
newError("looking for IP for domain: ", r.domain).WriteToLog()
|
||||
r.resolved = true
|
||||
ips, err := r.dns.LookupIP(r.domain)
|
||||
if err != nil {
|
||||
newError("failed to get IP address").Base(err).WriteToLog()
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
r.ip = make([]net.Address, len(ips))
|
||||
for i, ip := range ips {
|
||||
r.ip[i] = net.IPAddress(ip)
|
||||
}
|
||||
return r.ip
|
||||
}
|
||||
|
||||
func (r *Router) PickRoute(ctx context.Context) (string, error) {
|
||||
rule, err := r.pickRouteInternal(ctx)
|
||||
if err != nil {
|
||||
@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
|
||||
return rule.GetTag()
|
||||
}
|
||||
|
||||
// PickRoute implements routing.Router.
|
||||
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||
resolver := &ipResolver{
|
||||
dns: r.dns,
|
||||
func isDomainOutbound(outbound *session.Outbound) bool {
|
||||
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
|
||||
}
|
||||
|
||||
func (r *Router) resolveIP(outbound *session.Outbound) error {
|
||||
domain := outbound.Target.Address.Domain()
|
||||
ips, err := r.dns.LookupIP(domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
outbound.ResolvedIPs = ips
|
||||
return nil
|
||||
}
|
||||
|
||||
// PickRoute implements routing.Router.
|
||||
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||
outbound := session.OutboundFromContext(ctx)
|
||||
if r.domainStrategy == Config_IpOnDemand {
|
||||
if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
|
||||
resolver.domain = outbound.Target.Address.Domain()
|
||||
ctx = ContextWithResolveIPs(ctx, resolver)
|
||||
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
|
||||
if err := r.resolveIP(outbound); err != nil {
|
||||
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if outbound == nil || !outbound.Target.IsValid() {
|
||||
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
|
||||
return nil, common.ErrNoClue
|
||||
}
|
||||
|
||||
dest := outbound.Target
|
||||
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
|
||||
resolver.domain = dest.Address.Domain()
|
||||
ips := resolver.Resolve()
|
||||
if len(ips) > 0 {
|
||||
ctx = ContextWithResolveIPs(ctx, resolver)
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
if err := r.resolveIP(outbound); err != nil {
|
||||
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||
return nil, common.ErrNoClue
|
||||
}
|
||||
|
||||
// Try applying rules again if we have IPs.
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
|
||||
t.Error("expect tag 'test', bug actually ", tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPIfNonMatchDomain(t *testing.T) {
|
||||
config := &Config{
|
||||
DomainStrategy: Config_IpIfNonMatch,
|
||||
Rule: []*RoutingRule{
|
||||
{
|
||||
TargetTag: &RoutingRule_Tag{
|
||||
Tag: "test",
|
||||
},
|
||||
Cidr: []*CIDR{
|
||||
{
|
||||
Ip: []byte{192, 168, 0, 0},
|
||||
Prefix: 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockCtl := gomock.NewController(t)
|
||||
defer mockCtl.Finish()
|
||||
|
||||
mockDns := mocks.NewDNSClient(mockCtl)
|
||||
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
|
||||
|
||||
r := new(Router)
|
||||
common.Must(r.Init(config, mockDns, nil))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
t.Error("expect tag 'test', bug actually ", tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPIfNonMatchIP(t *testing.T) {
|
||||
config := &Config{
|
||||
DomainStrategy: Config_IpIfNonMatch,
|
||||
Rule: []*RoutingRule{
|
||||
{
|
||||
TargetTag: &RoutingRule_Tag{
|
||||
Tag: "test",
|
||||
},
|
||||
Cidr: []*CIDR{
|
||||
{
|
||||
Ip: []byte{127, 0, 0, 0},
|
||||
Prefix: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockCtl := gomock.NewController(t)
|
||||
defer mockCtl.Finish()
|
||||
|
||||
mockDns := mocks.NewDNSClient(mockCtl)
|
||||
|
||||
r := new(Router)
|
||||
common.Must(r.Init(config, mockDns, nil))
|
||||
|
||||
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
|
||||
tag, err := r.PickRoute(ctx)
|
||||
common.Must(err)
|
||||
if tag != "test" {
|
||||
t.Error("expect tag 'test', bug actually ", tag)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user