diff --git a/common/net/system.go b/common/net/system.go index e0ac48234..721ad6059 100644 --- a/common/net/system.go +++ b/common/net/system.go @@ -8,6 +8,8 @@ var DialUDP = net.DialUDP var DialUnix = net.DialUnix var Dial = net.Dial +type ListenConfig = net.ListenConfig + var Listen = net.Listen var ListenTCP = net.ListenTCP var ListenUDP = net.ListenUDP @@ -25,6 +27,7 @@ var CIDRMask = net.CIDRMask type Addr = net.Addr type Conn = net.Conn +type PacketConn = net.PacketConn type TCPAddr = net.TCPAddr type TCPConn = net.TCPConn diff --git a/testing/servers/tcp/tcp.go b/testing/servers/tcp/tcp.go index 0077630dc..6f49a34df 100644 --- a/testing/servers/tcp/tcp.go +++ b/testing/servers/tcp/tcp.go @@ -18,7 +18,7 @@ type Server struct { ShouldClose bool SendFirst []byte Listen net.Address - listener *net.TCPListener + listener net.Listener } func (server *Server) Start() (net.Destination, error) { @@ -30,17 +30,19 @@ func (server *Server) StartContext(ctx context.Context) (net.Destination, error) if listenerAddr == nil { listenerAddr = net.LocalHostIP } - listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{ + listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: listenerAddr.IP(), Port: int(server.Port), }) if err != nil { return net.Destination{}, err } - server.Port = net.Port(listener.Addr().(*net.TCPAddr).Port) - server.listener = listener - go server.acceptConnections(listener) + localAddr := listener.Addr().(*net.TCPAddr) + server.Port = net.Port(localAddr.Port) + server.listener = listener + go server.acceptConnections(listener.(*net.TCPListener)) + return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil } diff --git a/transport/internet/context.go b/transport/internet/context.go index 4d7323478..9d24dee23 100644 --- a/transport/internet/context.go +++ b/transport/internet/context.go @@ -11,8 +11,6 @@ type key int const ( streamSettingsKey key = iota dialerSrcKey - transportSettingsKey - securitySettingsKey ) func ContextWithStreamSettings(ctx context.Context, streamSettings *MemoryStreamConfig) context.Context { diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 5dcf2586c..23633150f 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -41,11 +41,15 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) { return dialer(ctx, dest) } - udpDialer := transportDialerCache["udp"] - if udpDialer == nil { - return nil, newError("UDP dialer not registered").AtError() + if dest.Network == net.Network_UDP { + udpDialer := transportDialerCache["udp"] + if udpDialer == nil { + return nil, newError("UDP dialer not registered").AtError() + } + return udpDialer(ctx, dest) } - return udpDialer(ctx, dest) + + return nil, newError("unknown network ", dest.Network) } // DialSystem calls system dialer to create a network connection. diff --git a/transport/internet/http/hub.go b/transport/internet/http/hub.go index d55adfc17..c70458693 100644 --- a/transport/internet/http/hub.go +++ b/transport/internet/http/hub.go @@ -117,7 +117,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int listener.server = server go func() { - tcpListener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{ + tcpListener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), }) diff --git a/transport/internet/sockopt.go b/transport/internet/sockopt.go index 48d230ddb..7facf30f4 100644 --- a/transport/internet/sockopt.go +++ b/transport/internet/sockopt.go @@ -8,3 +8,12 @@ func isTCPSocket(network string) bool { return false } } + +func isUDPSocket(network string) bool { + switch network { + case "udp", "udp4", "udp6": + return true + default: + return false + } +} diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 28208c2a7..31f2ab9a2 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -20,18 +20,26 @@ type SystemDialer interface { type DefaultSystemDialer struct { } +func getSocketSettings(ctx context.Context) *SocketConfig { + streamSettings := StreamSettingsFromContext(ctx) + if streamSettings != nil && streamSettings.SocketSettings != nil { + return streamSettings.SocketSettings + } + + return nil +} + func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) { dialer := &net.Dialer{ Timeout: time.Second * 60, DualStack: true, } - streamSettings := StreamSettingsFromContext(ctx) - if streamSettings != nil && streamSettings.SocketSettings != nil { - config := streamSettings.SocketSettings + sockopts := getSocketSettings(ctx) + if sockopts != nil { dialer.Control = func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { - if err := applyOutboundSocketOptions(network, address, fd, config); err != nil { + if err := applyOutboundSocketOptions(network, address, fd, sockopts); err != nil { newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx)) } }) diff --git a/transport/internet/system_listener.go b/transport/internet/system_listener.go index e9d438e9b..9c894072a 100644 --- a/transport/internet/system_listener.go +++ b/transport/internet/system_listener.go @@ -2,38 +2,48 @@ package internet import ( "context" + "syscall" "v2ray.com/core/common/net" "v2ray.com/core/common/session" ) var ( - effectiveTCPListener = DefaultTCPListener{} + effectiveListener = DefaultListener{} ) -type DefaultTCPListener struct{} +type DefaultListener struct{} -func (tl *DefaultTCPListener) Listen(ctx context.Context, addr *net.TCPAddr) (*net.TCPListener, error) { - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return nil, err - } +func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener, error) { + var lc net.ListenConfig - streamSettings := StreamSettingsFromContext(ctx) - if streamSettings != nil && streamSettings.SocketSettings != nil { - config := streamSettings.SocketSettings - rawConn, err := l.SyscallConn() - if err != nil { - return nil, err - } - if err := rawConn.Control(func(fd uintptr) { - if err := applyInboundSocketOptions("tcp", fd, config); err != nil { - newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx)) - } - }); err != nil { - return nil, err + sockopt := getSocketSettings(ctx) + if sockopt != nil { + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := applyInboundSocketOptions(network, fd, sockopt); err != nil { + newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx)) + } + }) } } - return l, nil + return lc.Listen(ctx, addr.Network(), addr.String()) +} + +func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { + var lc net.ListenConfig + + sockopt := getSocketSettings(ctx) + if sockopt != nil { + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := applyInboundSocketOptions(network, fd, sockopt); err != nil { + newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx)) + } + }) + } + } + + return lc.ListenPacket(ctx, addr.Network(), addr.String()) } diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 57055095a..0221803f2 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -14,7 +14,7 @@ import ( // Listener is an internet.Listener that listens for TCP connections. type Listener struct { - listener *net.TCPListener + listener net.Listener tlsConfig *gotls.Config authConfig internet.ConnectionAuthenticator config *Config @@ -23,7 +23,7 @@ type Listener struct { // ListenTCP creates a new Listener based on configurations. func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { - listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{ + listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), }) diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 3c12ac9c1..598a8034f 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -54,6 +54,10 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler return listener, nil } -func ListenSystemTCP(ctx context.Context, addr *net.TCPAddr) (*net.TCPListener, error) { - return effectiveTCPListener.Listen(ctx, addr) +func ListenSystem(ctx context.Context, addr net.Addr) (net.Listener, error) { + return effectiveListener.Listen(ctx, addr) +} + +func ListenSystemPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) { + return effectiveListener.ListenPacket(ctx, addr) } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index b9d4045b9..bbbbe06f5 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -85,7 +85,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i } func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config) (net.Listener, error) { - listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{ + listener, err := internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), })