diff --git a/proxy/http/client.go b/proxy/http/client.go index 2115572a6..318ff3327 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -4,11 +4,13 @@ package http import ( "bufio" - "io" "context" "encoding/base64" + "io" "net/http" "strings" + "sync" + "time" "v2ray.com/core" "v2ray.com/core/common" @@ -123,13 +125,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return nil } -type tunnelConn struct { - internet.Connection - header *buf.Buffer -} - // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method -func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, user *protocol.MemoryUser) *tunnelConn { +func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, user *protocol.MemoryUser) *tunConn { var headers []string destNetAddr := destination.NetAddr() headers = append(headers, "CONNECT "+destNetAddr+" HTTP/1.1") @@ -143,40 +140,59 @@ func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, use b := buf.New() b.WriteString(strings.Join(headers, "\r\n") + "\r\n\r\n") - return &tunnelConn { - Connection: conn, - header: b, - } + return newTunConn(conn, b, 5 * time.Millisecond) } -func (c *tunnelConn) Write(b []byte) (n int, err error) { +// tunConn is a connection that writes header before content, +// the header will be written during the next Write call or after +// specified delay. +type tunConn struct { + internet.Connection + header *buf.Buffer + once sync.Once + timer *time.Timer +} + +func newTunConn(conn internet.Connection, header *buf.Buffer, delay time.Duration) *tunConn { + tc := &tunConn{ + Connection: conn, + header: header, + } + if delay > 0 { + tc.timer = time.AfterFunc(delay, func() { + tc.Write([]byte{}) + }) + } + return tc +} + +func (c *tunConn) Write(b []byte) (n int, err error) { + // fallback to normal write if header is sent if c.header == nil { return c.Connection.Write(b) } - buffer := c.header - lenheader := c.header.Len() - // Concate header and b - _, err = buffer.Write(b) - if err != nil { - c.header.Resize(0, lenheader) - return 0, err - } - // Write buffer - nc, err := io.Copy(c.Connection, buffer) - if int32(nc) < lenheader { - c.header.Resize(int32(nc), lenheader) - return 0, err - } - c.header.Release() - c.header = nil - n = int(nc) - int(lenheader) - if err != nil { - return n, err - } - // Write trailing bytes - if n < len(b) { + // Prevent timer and writer race condition + c.once.Do(func() { + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + lenheader := c.header.Len() + // Concate header and b + common.Must2(c.header.Write(b)) + // Write buffer + var nc int64 + nc, err = io.Copy(c.Connection, c.header) + c.header.Release() + c.header = nil + n = int(nc) - int(lenheader) + if n < 0 { n = 0 } + b = b[n:] + }) + // Write Trailing bytes + if len(b) > 0 && err == nil { var nw int - nw, err = c.Connection.Write(b[:n]) + nw, err = c.Connection.Write(b) n += nw } return n, err