diff --git a/proxy/http/server.go b/proxy/http/server.go index cfe524292..a144dd0e8 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -87,10 +87,11 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { } func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher dispatcher.Interface) error { - reader := bufio.NewReaderSize(conn, 2048) - Start: - conn.SetReadDeadline(time.Now().Add(time.Second * 16)) + reader := bufio.NewReaderSize(conn, buf.Size) + if err := conn.SetReadDeadline(time.Now().Add(time.Second * 16)); err != nil { + return newError("unable to set read deadtime").Base(err) + } request, err := http.ReadRequest(reader) if err != nil { @@ -110,7 +111,9 @@ Start: } log.Trace(newError("request to Method [", request.Method, "] Host [", request.Host, "] with URL [", request.URL, "]")) - conn.SetReadDeadline(time.Time{}) + if err := conn.SetReadDeadline(time.Time{}); err != nil { + log.Trace(newError("unable to set back read deadline").Base(err)) + } defaultPort := net.Port(80) if strings.ToLower(request.URL.Scheme) == "https" { @@ -126,13 +129,23 @@ Start: } log.Access(conn.RemoteAddr(), request.URL, log.AccessAccepted, "") + // Get rid of bufio.Reader. + var firstPayload *buf.Buffer + if reader.Buffered() > 0 { + firstPayload = buf.New() + common.Must(firstPayload.Reset(func(b []byte) (int, error) { + return reader.Read(b[:reader.Buffered()]) + })) + } + reader = nil + if strings.ToUpper(request.Method) == "CONNECT" { - return s.handleConnect(ctx, request, reader, conn, dest, dispatcher) + return s.handleConnect(ctx, request, firstPayload, conn, conn, dest, dispatcher) } keepAlive := (strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive") - err = s.handlePlainHTTP(ctx, request, reader, conn, dest, dispatcher) + err = s.handlePlainHTTP(ctx, request, firstPayload, conn, conn, dest, dispatcher) if err == errWaitAnother { if keepAlive { goto Start @@ -143,7 +156,7 @@ Start: return err } -func (s *Server) handleConnect(ctx context.Context, request *http.Request, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher dispatcher.Interface) error { +func (s *Server) handleConnect(ctx context.Context, request *http.Request, payload *buf.Buffer, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher dispatcher.Interface) error { _, err := writer.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) if err != nil { return newError("failed to write back OK response").Base(err) @@ -160,14 +173,17 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade return err } + if !payload.IsEmpty() { + if err := ray.InboundInput().WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil { + return err + } + } + requestDone := signal.ExecuteAsync(func() error { defer ray.InboundInput().Close() v2reader := buf.NewReader(reader) - if err := buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer)); err != nil { - return err - } - return nil + return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer)) }) responseDone := signal.ExecuteAsync(func() error { @@ -219,7 +235,7 @@ func StripHopByHopHeaders(header http.Header) { var errWaitAnother = newError("keep alive") -func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher dispatcher.Interface) error { +func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, payload *buf.Buffer, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher dispatcher.Interface) error { if !s.config.AllowTransparent && len(request.URL.Host) <= 0 { // RFC 2068 (HTTP/1.1) requires URL to be absolute URL in HTTP proxy. response := &http.Response{ @@ -251,6 +267,12 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea output := ray.InboundOutput() defer input.Close() + if !payload.IsEmpty() { + if err := input.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil { + return err + } + } + var result error = errWaitAnother requestDone := signal.ExecuteAsync(func() error {