From 21a15bbf7441dd35c2dd53b4adad1c5caaacd6bf Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Tue, 3 Jan 2017 15:16:48 +0100 Subject: [PATCH] registerable dialer and listener --- transport/internet/dialer.go | 43 +++++++++++++------------- transport/internet/dialer_test.go | 2 +- transport/internet/kcp/dialer.go | 5 +-- transport/internet/kcp/listener.go | 3 +- transport/internet/tcp/dialer.go | 5 +-- transport/internet/tcp/hub.go | 3 +- transport/internet/tcp_hub.go | 32 +++++++++---------- transport/internet/udp/connection.go | 22 +++++++------ transport/internet/websocket/dialer.go | 5 +-- transport/internet/websocket/hub.go | 3 +- 10 files changed, 63 insertions(+), 60 deletions(-) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 7ffff0ad1..3bfcec1e2 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -20,42 +20,41 @@ type DialerOptions struct { type Dialer func(src v2net.Address, dest v2net.Destination, options DialerOptions) (Connection, error) var ( - TCPDialer Dialer - KCPDialer Dialer - UDPDialer Dialer - WSDialer Dialer + networkDialerCache = make(map[v2net.Network]Dialer) + ProxyDialer Dialer ) +func RegisterNetworkDialer(network v2net.Network, dialer Dialer) error { + if _, found := networkDialerCache[network]; found { + return errors.New("Internet|Dialer: ", network, " dialer already registered.") + } + networkDialerCache[network] = dialer + return nil +} + func Dial(src v2net.Address, dest v2net.Destination, options DialerOptions) (Connection, error) { if options.Proxy.HasTag() && ProxyDialer != nil { log.Info("Internet: Proxying outbound connection through: ", options.Proxy.Tag) return ProxyDialer(src, dest, options) } - var connection Connection - var err error if dest.Network == v2net.Network_TCP { - switch options.Stream.Network { - case v2net.Network_TCP: - connection, err = TCPDialer(src, dest, options) - case v2net.Network_KCP: - connection, err = KCPDialer(src, dest, options) - case v2net.Network_WebSocket: - connection, err = WSDialer(src, dest, options) - default: - return nil, ErrUnsupportedStreamType + dialer := networkDialerCache[options.Stream.Network] + if dialer == nil { + return nil, errors.New("Internet|Dialer: ", options.Stream.Network, " dialer not registered.") } - if err != nil { - return nil, err - } - - return connection, nil + return dialer(src, dest, options) } - return UDPDialer(src, dest, options) + udpDialer := networkDialerCache[v2net.Network_UDP] + if udpDialer == nil { + return nil, errors.New("Internet|Dialer: UDP dialer not registered.") + } + return udpDialer(src, dest, options) } -func DialToDest(src v2net.Address, dest v2net.Destination) (net.Conn, error) { +// DialSystem calls system dialer to create a network connection. +func DialSystem(src v2net.Address, dest v2net.Destination) (net.Conn, error) { return effectiveSystemDialer.Dial(src, dest) } diff --git a/transport/internet/dialer_test.go b/transport/internet/dialer_test.go index c58c6e94a..56f500649 100644 --- a/transport/internet/dialer_test.go +++ b/transport/internet/dialer_test.go @@ -17,7 +17,7 @@ func TestDialWithLocalAddr(t *testing.T) { assert.Error(err).IsNil() defer server.Close() - conn, err := DialToDest(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, dest.Port)) + conn, err := DialSystem(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, dest.Port)) assert.Error(err).IsNil() assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String()) conn.Close() diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 90cdb5a6e..71a4125be 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -8,6 +8,7 @@ import ( "crypto/cipher" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/dice" "v2ray.com/core/common/errors" @@ -115,7 +116,7 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO id := internal.NewConnectionID(src, dest) conn := globalPool.Get(id) if conn == nil { - rawConn, err := internet.DialToDest(src, dest) + rawConn, err := internet.DialSystem(src, dest) if err != nil { log.Error("KCP|Dialer: Failed to dial to dest: ", err) return nil, err @@ -172,5 +173,5 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO } func init() { - internet.KCPDialer = DialKCP + common.Must(internet.RegisterNetworkDialer(v2net.Network_KCP, DialKCP)) } diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index dae4dc0d2..0d139e88e 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -9,6 +9,7 @@ import ( "crypto/cipher" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" "v2ray.com/core/common/log" @@ -297,5 +298,5 @@ func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOp } func init() { - internet.KCPListenFunc = ListenKCP + common.Must(internet.RegisterNetworkListener(v2net.Network_KCP, ListenKCP)) } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 095312f36..612f8ba5b 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" + "v2ray.com/core/common" "v2ray.com/core/common/errors" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" @@ -34,7 +35,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti } if conn == nil { var err error - conn, err = internet.DialToDest(src, dest) + conn, err = internet.DialSystem(src, dest) if err != nil { return nil, err } @@ -69,5 +70,5 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti } func init() { - internet.TCPDialer = Dial + common.Must(internet.RegisterNetworkDialer(v2net.Network_TCP, Dial)) } diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 6acba11e3..194457f6f 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "v2ray.com/core/common" "v2ray.com/core/common/errors" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" @@ -158,5 +159,5 @@ func (v *TCPListener) Close() error { } func init() { - internet.TCPListenFunc = ListenTCP + common.Must(internet.RegisterNetworkListener(v2net.Network_TCP, ListenTCP)) } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 8d3a76362..a54c50af9 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -13,11 +13,17 @@ import ( var ( ErrClosedConnection = errors.New("Connection already closed.") - KCPListenFunc ListenFunc - TCPListenFunc ListenFunc - WSListenFunc ListenFunc + networkListenerCache = make(map[v2net.Network]ListenFunc) ) +func RegisterNetworkListener(network v2net.Network, listener ListenFunc) error { + if _, found := networkListenerCache[network]; found { + return errors.New("Internet|TCPHub: ", network, " listener already registered.") + } + networkListenerCache[network] = listener + return nil +} + type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error) type ListenOptions struct { Stream *StreamConfig @@ -37,26 +43,16 @@ type TCPHub struct { } func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) { - var listener Listener - var err error options := ListenOptions{ Stream: settings, } - switch settings.Network { - case v2net.Network_TCP: - listener, err = TCPListenFunc(address, port, options) - case v2net.Network_KCP: - listener, err = KCPListenFunc(address, port, options) - case v2net.Network_WebSocket: - listener, err = WSListenFunc(address, port, options) - default: - log.Error("Internet|Listener: Unknown stream type: ", settings.Network) - err = ErrUnsupportedStreamType + listenFunc := networkListenerCache[settings.Network] + if listenFunc == nil { + return nil, errors.New("Internet|TCPHub: ", settings.Network, " listener not registered.") } - + listener, err := listenFunc(address, port, options) if err != nil { - log.Warning("Internet|Listener: Failed to listen on ", address, ":", port) - return nil, err + return nil, errors.Base(err).Message("Interent|TCPHub: Failed to listen: ") } hub := &TCPHub{ diff --git a/transport/internet/udp/connection.go b/transport/internet/udp/connection.go index fcd1a9664..65399b1db 100644 --- a/transport/internet/udp/connection.go +++ b/transport/internet/udp/connection.go @@ -3,6 +3,7 @@ package udp import ( "net" + "v2ray.com/core/common" v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" ) @@ -18,14 +19,15 @@ func (v *Connection) Reusable() bool { func (v *Connection) SetReusable(b bool) {} func init() { - internet.UDPDialer = func(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) { - conn, err := internet.DialToDest(src, dest) - if err != nil { - return nil, err - } - // TODO: handle dialer options - return &Connection{ - UDPConn: *(conn.(*net.UDPConn)), - }, nil - } + common.Must(internet.RegisterNetworkDialer(v2net.Network_UDP, + func(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) { + conn, err := internet.DialSystem(src, dest) + if err != nil { + return nil, err + } + // TODO: handle dialer options + return &Connection{ + UDPConn: *(conn.(*net.UDPConn)), + }, nil + })) } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index c7a3987ba..fe34b3fa7 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -5,6 +5,7 @@ import ( "net" "github.com/gorilla/websocket" + "v2ray.com/core/common" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" @@ -46,7 +47,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti } func init() { - internet.WSDialer = Dial + common.Must(internet.RegisterNetworkDialer(v2net.Network_WebSocket, Dial)) } func wsDial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (*wsconn, error) { @@ -57,7 +58,7 @@ func wsDial(src v2net.Address, dest v2net.Destination, options internet.DialerOp wsSettings := networkSettings.(*Config) commonDial := func(network, addr string) (net.Conn, error) { - return internet.DialToDest(src, dest) + return internet.DialSystem(src, dest) } dialer := websocket.Dialer{ diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 234d8af04..ba90a8be1 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "v2ray.com/core/common" "v2ray.com/core/common/errors" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" @@ -197,5 +198,5 @@ func (v *WSListener) Close() error { } func init() { - internet.WSListenFunc = ListenWS + common.Must(internet.RegisterNetworkListener(v2net.Network_WebSocket, ListenWS)) }