diff --git a/app/router/condition.go b/app/router/condition.go index cf188686b..54fe9264b 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -144,7 +144,7 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { dest := m.destFunc(ctx) - if dest.IsValid() && (dest.Address.Family().IsIPv4() || dest.Address.Family().IsIPv6()) { + if dest.IsValid() && dest.Address.Family().IsIP() { ips = append(ips, dest.Address.IP()) } else if resolver, ok := ResolvedIPsFromContext(ctx); ok { resolvedIPs := resolver.Resolve() diff --git a/common/net/address.go b/common/net/address.go index da15b803b..1794ce7ff 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -35,16 +35,6 @@ const ( AddressFamilyDomain = AddressFamily(2) ) -// Either returns true if current AddressFamily matches any of the AddressFamilies provided. -func (af AddressFamily) Either(fs ...AddressFamily) bool { - for _, f := range fs { - if af == f { - return true - } - } - return false -} - // IsIPv4 returns true if current AddressFamily is IPv4. func (af AddressFamily) IsIPv4() bool { return af == AddressFamilyIPv4 @@ -55,6 +45,11 @@ func (af AddressFamily) IsIPv6() bool { return af == AddressFamilyIPv6 } +// IsIP returns true if current AddressFamily is IPv6 or IPv4. +func (af AddressFamily) IsIP() bool { + return af == AddressFamilyIPv4 || af == AddressFamilyIPv6 +} + // IsDomain returns true if current AddressFamily is Domain. func (af AddressFamily) IsDomain() bool { return af == AddressFamilyDomain diff --git a/common/protocol/address.go b/common/protocol/address.go index a3c628d4d..5bd974d49 100644 --- a/common/protocol/address.go +++ b/common/protocol/address.go @@ -213,7 +213,7 @@ func (p *addressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres domain := string(b.BytesFrom(-domainLength)) if maybeIPPrefix(domain[0]) { addr := net.ParseAddress(domain) - if addr.Family().IsIPv4() || addr.Family().IsIPv6() { + if addr.Family().IsIP() { return addr, nil } } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index bbbbe06f5..5dcb12364 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -41,7 +41,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) remoteAddr := conn.RemoteAddr() - if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().Either(net.AddressFamilyIPv4, net.AddressFamilyIPv6) { + if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { remoteAddr.(*net.TCPAddr).IP = forwardedAddrs[0].IP() }