From a5caa01cb6224c84618149829876e10b15b49f73 Mon Sep 17 00:00:00 2001 From: Anonymous-Someneese Date: Thu, 2 Jan 2020 21:09:33 +0800 Subject: [PATCH] Optimize HTTP tunnel setup in TFO environment --- proxy/http/client.go | 51 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/proxy/http/client.go b/proxy/http/client.go index 740865a76..2115572a6 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -4,9 +4,9 @@ package http import ( "bufio" + "io" "context" "encoding/base64" - "io" "net/http" "strings" @@ -93,9 +93,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter p = c.policyManager.ForLevel(user.Level) } - if err := setUpHTTPTunnel(conn, &destination, user); err != nil { - return err - } + conn = setUpHTTPTunnel(conn, &destination, user) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) @@ -125,8 +123,13 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return nil } +type tunnelConn struct { + internet.Connection + header *buf.Buffer +} + // setUpHTTPTunnel will create a socket tunnel via HTTP CONNECT method -func setUpHTTPTunnel(writer io.Writer, destination *net.Destination, user *protocol.MemoryUser) error { +func setUpHTTPTunnel(conn internet.Connection, destination *net.Destination, user *protocol.MemoryUser) *tunnelConn { var headers []string destNetAddr := destination.NetAddr() headers = append(headers, "CONNECT "+destNetAddr+" HTTP/1.1") @@ -140,11 +143,43 @@ func setUpHTTPTunnel(writer io.Writer, destination *net.Destination, user *proto b := buf.New() b.WriteString(strings.Join(headers, "\r\n") + "\r\n\r\n") - if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil { - return err + return &tunnelConn { + Connection: conn, + header: b, } +} - return nil +func (c *tunnelConn) Write(b []byte) (n int, err error) { + if c.header == nil { + return c.Connection.Write(b) + } + buffer := c.header + lenheader := c.header.Len() + // Concate header and b + _, err = buffer.Write(b) + if err != nil { + c.header.Resize(0, lenheader) + return 0, err + } + // Write buffer + nc, err := io.Copy(c.Connection, buffer) + if int32(nc) < lenheader { + c.header.Resize(int32(nc), lenheader) + return 0, err + } + c.header.Release() + c.header = nil + n = int(nc) - int(lenheader) + if err != nil { + return n, err + } + // Write trailing bytes + if n < len(b) { + var nw int + nw, err = c.Connection.Write(b[:n]) + n += nw + } + return n, err } func init() {