diff --git a/app/router/condition.go b/app/router/condition.go index 1f956c8da..38d4364b3 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -120,22 +120,6 @@ func (m *DomainMatcher) Apply(ctx context.Context) bool { return m.ApplyDomain(dest.Address.Domain()) } -type CIDRMatcher struct { - cidr *net.IPNet - onSource bool -} - -func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) { - cidr := &net.IPNet{ - IP: net.IP(ip), - Mask: net.CIDRMask(int(mask), len(ip)*8), - } - return &CIDRMatcher{ - cidr: cidr, - onSource: onSource, - }, nil -} - func sourceFromContext(ctx context.Context) net.Destination { inbound := session.InboundFromContext(ctx) if inbound == nil { @@ -152,80 +136,6 @@ func targetFromContent(ctx context.Context) net.Destination { return outbound.Target } -func (v *CIDRMatcher) Apply(ctx context.Context) bool { - ips := make([]net.IP, 0, 4) - if resolver, ok := ResolvedIPsFromContext(ctx); ok { - resolvedIPs := resolver.Resolve() - for _, rip := range resolvedIPs { - if !rip.Family().IsIPv6() { - continue - } - ips = append(ips, rip.IP()) - } - } - - var dest net.Destination - if v.onSource { - dest = sourceFromContext(ctx) - } else { - dest = targetFromContent(ctx) - } - - if dest.IsValid() && dest.Address.Family().IsIPv6() { - ips = append(ips, dest.Address.IP()) - } - - for _, ip := range ips { - if v.cidr.Contains(ip) { - return true - } - } - return false -} - -type IPv4Matcher struct { - ipv4net *net.IPNetTable - onSource bool -} - -func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher { - return &IPv4Matcher{ - ipv4net: ipnet, - onSource: onSource, - } -} - -func (v *IPv4Matcher) Apply(ctx context.Context) bool { - ips := make([]net.IP, 0, 4) - if resolver, ok := ResolvedIPsFromContext(ctx); ok { - resolvedIPs := resolver.Resolve() - for _, rip := range resolvedIPs { - if !rip.Family().IsIPv4() { - continue - } - ips = append(ips, rip.IP()) - } - } - - var dest net.Destination - if v.onSource { - dest = sourceFromContext(ctx) - } else { - dest = targetFromContent(ctx) - } - - if dest.IsValid() && dest.Address.Family().IsIPv4() { - ips = append(ips, dest.Address.IP()) - } - - for _, ip := range ips { - if v.ipv4net.Contains(ip) { - return true - } - } - return false -} - type MultiGeoIPMatcher struct { matchers []*GeoIPMatcher onSource bool diff --git a/app/router/config.go b/app/router/config.go index 23d268eeb..acdada0cf 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -2,8 +2,6 @@ package router import ( "context" - - "v2ray.com/core/common/net" ) // CIDRList is an alias of []*CIDR to provide sort.Interface. @@ -54,40 +52,6 @@ func (r *Rule) Apply(ctx context.Context) bool { return r.Condition.Apply(ctx) } -func cidrToCondition(cidr []*CIDR, source bool) (Condition, error) { - ipv4Net := net.NewIPNetTable() - ipv6Cond := NewAnyCondition() - hasIpv6 := false - - for _, ip := range cidr { - switch len(ip.Ip) { - case net.IPv4len: - ipv4Net.AddIP(ip.Ip, byte(ip.Prefix)) - case net.IPv6len: - hasIpv6 = true - matcher, err := NewCIDRMatcher(ip.Ip, ip.Prefix, source) - if err != nil { - return nil, err - } - ipv6Cond.Add(matcher) - default: - return nil, newError("invalid IP length").AtWarning() - } - } - - switch { - case !ipv4Net.IsEmpty() && hasIpv6: - cond := NewAnyCondition() - cond.Add(NewIPv4Matcher(ipv4Net, source)) - cond.Add(ipv6Cond) - return cond, nil - case !ipv4Net.IsEmpty(): - return NewIPv4Matcher(ipv4Net, source), nil - default: - return ipv6Cond, nil - } -} - func (rr *RoutingRule) BuildCondition() (Condition, error) { conds := NewConditionChan() @@ -122,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } conds.Add(cond) } else if len(rr.Cidr) > 0 { - cond, err := cidrToCondition(rr.Cidr, false) + cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.Cidr}}, false) if err != nil { return nil, err } @@ -136,7 +100,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } conds.Add(cond) } else if len(rr.SourceCidr) > 0 { - cond, err := cidrToCondition(rr.SourceCidr, true) + cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.SourceCidr}}, true) if err != nil { return nil, err } diff --git a/common/net/ipnet.go b/common/net/ipnet.go deleted file mode 100644 index c30f2e882..000000000 --- a/common/net/ipnet.go +++ /dev/null @@ -1,83 +0,0 @@ -package net - -import ( - "math/bits" - "net" -) - -type IPNetTable struct { - cache map[uint32]byte -} - -func NewIPNetTable() *IPNetTable { - return &IPNetTable{ - cache: make(map[uint32]byte, 1024), - } -} - -func ipToUint32(ip IP) uint32 { - value := uint32(0) - for _, b := range []byte(ip) { - value <<= 8 - value += uint32(b) - } - return value -} - -func ipMaskToByte(mask net.IPMask) byte { - value := byte(0) - for _, b := range []byte(mask) { - value += byte(bits.OnesCount8(b)) - } - return value -} - -func (n *IPNetTable) Add(ipNet *net.IPNet) { - ipv4 := ipNet.IP.To4() - if ipv4 == nil { - // For now, we don't support IPv6 - return - } - mask := ipMaskToByte(ipNet.Mask) - n.AddIP(ipv4, mask) -} - -func (n *IPNetTable) AddIP(ip []byte, mask byte) { - k := ipToUint32(ip) - k = (k >> (32 - mask)) << (32 - mask) // normalize ip - existing, found := n.cache[k] - if !found || existing > mask { - n.cache[k] = mask - } -} - -func (n *IPNetTable) Contains(ip net.IP) bool { - ipv4 := ip.To4() - if ipv4 == nil { - return false - } - originalValue := ipToUint32(ipv4) - - if entry, found := n.cache[originalValue]; found { - if entry == 32 { - return true - } - } - - mask := uint32(0) - for maskbit := byte(1); maskbit <= 32; maskbit++ { - mask += 1 << uint32(32-maskbit) - - maskedValue := originalValue & mask - if entry, found := n.cache[maskedValue]; found { - if entry == maskbit { - return true - } - } - } - return false -} - -func (n *IPNetTable) IsEmpty() bool { - return len(n.cache) == 0 -} diff --git a/common/net/ipnet_test.go b/common/net/ipnet_test.go deleted file mode 100644 index 77bab96b7..000000000 --- a/common/net/ipnet_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package net_test - -import ( - "net" - "os" - "path/filepath" - "testing" - - proto "github.com/golang/protobuf/proto" - "v2ray.com/core/app/router" - "v2ray.com/core/common/platform" - - "v2ray.com/ext/sysio" - - "v2ray.com/core/common" - . "v2ray.com/core/common/net" - . "v2ray.com/ext/assert" -) - -func parseCIDR(str string) *net.IPNet { - _, ipNet, err := net.ParseCIDR(str) - common.Must(err) - return ipNet -} - -func TestIPNet(t *testing.T) { - assert := With(t) - - ipNet := NewIPNetTable() - ipNet.Add(parseCIDR(("0.0.0.0/8"))) - ipNet.Add(parseCIDR(("10.0.0.0/8"))) - ipNet.Add(parseCIDR(("100.64.0.0/10"))) - ipNet.Add(parseCIDR(("127.0.0.0/8"))) - ipNet.Add(parseCIDR(("169.254.0.0/16"))) - ipNet.Add(parseCIDR(("172.16.0.0/12"))) - ipNet.Add(parseCIDR(("192.0.0.0/24"))) - ipNet.Add(parseCIDR(("192.0.2.0/24"))) - ipNet.Add(parseCIDR(("192.168.0.0/16"))) - ipNet.Add(parseCIDR(("198.18.0.0/15"))) - ipNet.Add(parseCIDR(("198.51.100.0/24"))) - ipNet.Add(parseCIDR(("203.0.113.0/24"))) - ipNet.Add(parseCIDR(("8.8.8.8/32"))) - ipNet.AddIP(net.ParseIP("91.108.4.0"), 16) - assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue) - assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue) - assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse) - assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue) - assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse) - assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse) - assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue) - assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse) - assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue) -} - -func TestGeoIPCN(t *testing.T) { - assert := With(t) - common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat"))) - - ips, err := loadGeoIP("CN") - common.Must(err) - - ipNet := NewIPNetTable() - for _, ip := range ips { - ipNet.AddIP(ip.Ip, byte(ip.Prefix)) - } - - assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse) -} - -func loadGeoIP(country string) ([]*router.CIDR, error) { - geoipBytes, err := sysio.ReadAsset("geoip.dat") - if err != nil { - return nil, err - } - var geoipList router.GeoIPList - if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil { - return nil, err - } - - for _, geoip := range geoipList.Entry { - if geoip.CountryCode == country { - return geoip.Cidr, nil - } - } - - panic("country not found: " + country) -} - -func BenchmarkIPNetQuery(b *testing.B) { - common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat"))) - - ips, err := loadGeoIP("CN") - common.Must(err) - - ipNet := NewIPNetTable() - for _, ip := range ips { - ipNet.AddIP(ip.Ip, byte(ip.Prefix)) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - ipNet.Contains([]byte{8, 8, 8, 8}) - } -} - -func BenchmarkCIDRQuery(b *testing.B) { - common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat"))) - - ips, err := loadGeoIP("CN") - common.Must(err) - - ipNet := make([]*net.IPNet, 0, 1024) - for _, ip := range ips { - if len(ip.Ip) != 4 { - continue - } - ipNet = append(ipNet, &net.IPNet{ - IP: net.IP(ip.Ip), - Mask: net.CIDRMask(int(ip.Prefix), 32), - }) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - for _, n := range ipNet { - if n.Contains([]byte{8, 8, 8, 8}) { - break - } - } - } -}