diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index 2009db7cd..123621aa5 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -17,7 +17,7 @@ var ( // connection is a wrapper for net.Conn over WebSocket connection. type connection struct { - wsc *websocket.Conn + conn *websocket.Conn reader io.Reader mergingReader buf.Reader @@ -26,7 +26,7 @@ type connection struct { func newConnection(conn *websocket.Conn) *connection { return &connection{ - wsc: conn, + conn: conn, } } @@ -59,7 +59,7 @@ func (c *connection) getReader() (io.Reader, error) { return c.reader, nil } - _, reader, err := c.wsc.NextReader() + _, reader, err := c.conn.NextReader() if err != nil { return nil, err } @@ -69,7 +69,7 @@ func (c *connection) getReader() (io.Reader, error) { // Write implements io.Writer. func (c *connection) Write(b []byte) (int, error) { - if err := c.wsc.WriteMessage(websocket.BinaryMessage, b); err != nil { + if err := c.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { return 0, err } return len(b), nil @@ -83,16 +83,16 @@ func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error { } func (c *connection) Close() error { - c.wsc.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) - return c.wsc.Close() + c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) + return c.conn.Close() } func (c *connection) LocalAddr() net.Addr { - return c.wsc.LocalAddr() + return c.conn.LocalAddr() } func (c *connection) RemoteAddr() net.Addr { - return c.wsc.RemoteAddr() + return c.conn.RemoteAddr() } func (c *connection) SetDeadline(t time.Time) error { @@ -103,9 +103,9 @@ func (c *connection) SetDeadline(t time.Time) error { } func (c *connection) SetReadDeadline(t time.Time) error { - return c.wsc.SetReadDeadline(t) + return c.conn.SetReadDeadline(t) } func (c *connection) SetWriteDeadline(t time.Time) error { - return c.wsc.SetWriteDeadline(t) + return c.conn.SetWriteDeadline(t) } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index cdeb95650..c42c4c9a4 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -68,7 +68,5 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) return nil, newError("failed to dial to (", uri, "): ", reason).Base(err) } - return &connection{ - wsc: conn, - }, nil + return newConnection(conn), nil } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 9183688df..048df8152 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -90,19 +90,24 @@ func (ln *Listener) listenws(address net.Address, port net.Port) error { ln.listener = listener go func() { - http.Serve(listener, &requestHandler{ + err := http.Serve(listener, &requestHandler{ path: ln.config.GetNormailzedPath(), ln: ln, }) + if err != nil { + log.Trace(newError("failed to serve http for WebSocket").Base(err).AtWarning()) + } }() return nil } +// Addr implements net.Listener.Addr(). func (ln *Listener) Addr() net.Addr { return ln.listener.Addr() } +// Close implements net.Listener.Close(). func (ln *Listener) Close() error { return ln.listener.Close() } diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index 2501f5ad0..49746ed20 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/testing/assert" tlsgen "v2ray.com/core/testing/tls" @@ -88,7 +89,7 @@ func Test_listenWSAndDial_TLS(t *testing.T) { }) listen, err := ListenWS(ctx, net.DomainAddress("localhost"), 13143, func(ctx context.Context, conn internet.Connection) bool { go func() { - conn.Close() + common.Must(conn.Close()) }() return true }) @@ -97,5 +98,5 @@ func Test_listenWSAndDial_TLS(t *testing.T) { conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13143)) assert.Error(err).IsNil() - conn.Close() + common.Must(conn.Close()) }