diff --git a/transport/internet/internal/sysfd.go b/transport/internet/internal/sysfd.go index e51e00e85..d2720aed2 100644 --- a/transport/internet/internal/sysfd.go +++ b/transport/internet/internal/sysfd.go @@ -6,7 +6,7 @@ import ( ) var ( - errInvalidConn = newError("Invalid Connection.") + errInvalidConn = newError("not a net.Conn") ) // GetSysFd returns the underlying fd of a connection. diff --git a/transport/internet/tcp/sockopt_linux.go b/transport/internet/tcp/sockopt_linux.go index 3e1fe56ba..29aeb4c31 100644 --- a/transport/internet/tcp/sockopt_linux.go +++ b/transport/internet/tcp/sockopt_linux.go @@ -3,33 +3,30 @@ package tcp import ( + "net" "syscall" "v2ray.com/core/app/log" - "v2ray.com/core/common/net" + v2net "v2ray.com/core/common/net" "v2ray.com/core/transport/internet" + "v2ray.com/core/transport/internet/internal" ) const SO_ORIGINAL_DST = 80 -func GetOriginalDestination(conn internet.Connection) net.Destination { - tcpConn, ok := conn.(internet.SysFd) - if !ok { - log.Trace(newError("failed to get sys fd")) - return net.Destination{} - } - fd, err := tcpConn.SysFd() +func GetOriginalDestination(conn internet.Connection) v2net.Destination { + fd, err := internal.GetSysFd(conn.(net.Conn)) if err != nil { log.Trace(newError("failed to get original destination").Base(err)) - return net.Destination{} + return v2net.Destination{} } addr, err := syscall.GetsockoptIPv6Mreq(fd, syscall.IPPROTO_IP, SO_ORIGINAL_DST) if err != nil { log.Trace(newError("failed to call getsockopt").Base(err)) - return net.Destination{} + return v2net.Destination{} } - ip := net.IPAddress(addr.Multiaddr[4:8]) + ip := v2net.IPAddress(addr.Multiaddr[4:8]) port := uint16(addr.Multiaddr[2])<<8 + uint16(addr.Multiaddr[3]) - return net.TCPDestination(ip, net.Port(port)) + return v2net.TCPDestination(ip, v2net.Port(port)) } diff --git a/transport/internet/tcp/sockopt_linux_test.go b/transport/internet/tcp/sockopt_linux_test.go new file mode 100644 index 000000000..5ea28b410 --- /dev/null +++ b/transport/internet/tcp/sockopt_linux_test.go @@ -0,0 +1,28 @@ +// +build linux + +package tcp_test + +import ( + "context" + "testing" + + "v2ray.com/core/testing/assert" + "v2ray.com/core/testing/servers/tcp" +) + +func TestGetOriginalDestination(t *testing.T) { + assert := assert.On(t) + + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + assert.Error(err).IsNil() + defer tcpServer.Close() + + conn, err := Dial(context.Background(), dest) + assert.Error(err).IsNil() + + _, err := GetOriginalDestination(conn) + assert.String(err.Error()).Contains("failed to call getsockopt") +}