mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-02 23:47:07 -05:00
Fixed HTTP response not adjusted based on request
This commit is contained in:
parent
38e89bd2c7
commit
087a62ef3d
@ -3,6 +3,7 @@ package http
|
||||
//go:generate errorgen
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
@ -28,6 +29,8 @@ const (
|
||||
|
||||
var (
|
||||
ErrHeaderToLong = newError("Header too long.")
|
||||
|
||||
ErrHeaderMisMatch = newError("Header Mismatch.")
|
||||
)
|
||||
|
||||
type Reader interface {
|
||||
@ -51,12 +54,22 @@ func (NoOpWriter) Write(io.Writer) error {
|
||||
}
|
||||
|
||||
type HeaderReader struct {
|
||||
req *http.Request
|
||||
expectedHeader *RequestConfig
|
||||
}
|
||||
|
||||
func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
||||
func (h *HeaderReader) ExpectThisRequest(expectedHeader *RequestConfig) *HeaderReader {
|
||||
h.expectedHeader = expectedHeader
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
||||
buffer := buf.New()
|
||||
totalBytes := int32(0)
|
||||
endingDetected := false
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
|
||||
for totalBytes < maxHeaderLength {
|
||||
_, err := buffer.ReadFrom(reader)
|
||||
if err != nil {
|
||||
@ -64,6 +77,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
||||
return nil, err
|
||||
}
|
||||
if n := bytes.Index(buffer.Bytes(), []byte(ENDING)); n != -1 {
|
||||
headerBuf.Write(buffer.BytesRange(0, int32(n+len(ENDING))))
|
||||
buffer.Advance(int32(n + len(ENDING)))
|
||||
endingDetected = true
|
||||
break
|
||||
@ -71,19 +85,52 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
|
||||
lenEnding := int32(len(ENDING))
|
||||
if buffer.Len() >= lenEnding {
|
||||
totalBytes += buffer.Len() - lenEnding
|
||||
headerBuf.Write(buffer.BytesRange(0, buffer.Len()-lenEnding))
|
||||
leftover := buffer.BytesFrom(-lenEnding)
|
||||
buffer.Clear()
|
||||
copy(buffer.Extend(lenEnding), leftover)
|
||||
}
|
||||
}
|
||||
if buffer.IsEmpty() {
|
||||
buffer.Release()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !endingDetected {
|
||||
buffer.Release()
|
||||
return nil, ErrHeaderToLong
|
||||
}
|
||||
|
||||
if h.expectedHeader == nil {
|
||||
if buffer.IsEmpty() {
|
||||
buffer.Release()
|
||||
return nil, nil
|
||||
}
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
//Parse the request
|
||||
|
||||
if req, err := readRequest(bufio.NewReader(bytes.NewReader(headerBuf.Bytes())), false); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
h.req = req
|
||||
}
|
||||
|
||||
//Check req
|
||||
path := h.req.URL.Path
|
||||
hasThisUri := false
|
||||
for _, u := range h.expectedHeader.Uri {
|
||||
if u == path {
|
||||
hasThisUri = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasThisUri == false {
|
||||
return nil, ErrHeaderMisMatch
|
||||
}
|
||||
|
||||
if buffer.IsEmpty() {
|
||||
buffer.Release()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
@ -110,18 +157,24 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
|
||||
type HttpConn struct {
|
||||
net.Conn
|
||||
|
||||
readBuffer *buf.Buffer
|
||||
oneTimeReader Reader
|
||||
oneTimeWriter Writer
|
||||
errorWriter Writer
|
||||
readBuffer *buf.Buffer
|
||||
oneTimeReader Reader
|
||||
oneTimeWriter Writer
|
||||
errorWriter Writer
|
||||
errorMismatchWriter Writer
|
||||
errorTooLongWriter Writer
|
||||
|
||||
errReason error
|
||||
}
|
||||
|
||||
func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer) *HttpConn {
|
||||
func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer, errorMismatchWriter Writer, errorTooLongWriter Writer) *HttpConn {
|
||||
return &HttpConn{
|
||||
Conn: conn,
|
||||
oneTimeReader: reader,
|
||||
oneTimeWriter: writer,
|
||||
errorWriter: errorWriter,
|
||||
Conn: conn,
|
||||
oneTimeReader: reader,
|
||||
oneTimeWriter: writer,
|
||||
errorWriter: errorWriter,
|
||||
errorMismatchWriter: errorMismatchWriter,
|
||||
errorTooLongWriter: errorTooLongWriter,
|
||||
}
|
||||
}
|
||||
|
||||
@ -129,6 +182,7 @@ func (c *HttpConn) Read(b []byte) (int, error) {
|
||||
if c.oneTimeReader != nil {
|
||||
buffer, err := c.oneTimeReader.Read(c.Conn)
|
||||
if err != nil {
|
||||
c.errReason = err
|
||||
return 0, err
|
||||
}
|
||||
c.readBuffer = buffer
|
||||
@ -165,7 +219,16 @@ func (c *HttpConn) Close() error {
|
||||
if c.oneTimeWriter != nil && c.errorWriter != nil {
|
||||
// Connection is being closed but header wasn't sent. This means the client request
|
||||
// is probably not valid. Sending back a server error header in this case.
|
||||
c.errorWriter.Write(c.Conn)
|
||||
|
||||
//Write response based on error reason
|
||||
|
||||
if c.errReason == ErrHeaderMisMatch {
|
||||
c.errorMismatchWriter.Write(c.Conn)
|
||||
} else if c.errReason == ErrHeaderToLong {
|
||||
c.errorTooLongWriter.Write(c.Conn)
|
||||
} else {
|
||||
c.errorWriter.Write(c.Conn)
|
||||
}
|
||||
}
|
||||
|
||||
return c.Conn.Close()
|
||||
@ -230,36 +293,17 @@ func (a HttpAuthenticator) Client(conn net.Conn) net.Conn {
|
||||
if a.config.Response != nil {
|
||||
writer = a.GetClientWriter()
|
||||
}
|
||||
return NewHttpConn(conn, reader, writer, NoOpWriter{})
|
||||
return NewHttpConn(conn, reader, writer, NoOpWriter{}, NoOpWriter{}, NoOpWriter{})
|
||||
}
|
||||
|
||||
func (a HttpAuthenticator) Server(conn net.Conn) net.Conn {
|
||||
if a.config.Request == nil && a.config.Response == nil {
|
||||
return conn
|
||||
}
|
||||
return NewHttpConn(conn, new(HeaderReader), a.GetServerWriter(), formResponseHeader(&ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "500",
|
||||
Reason: "Internal Server Error",
|
||||
},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Connection",
|
||||
Value: []string{"close"},
|
||||
},
|
||||
{
|
||||
Name: "Cache-Control",
|
||||
Value: []string{"private"},
|
||||
},
|
||||
{
|
||||
Name: "Content-Length",
|
||||
Value: []string{"0"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
return NewHttpConn(conn, new(HeaderReader).ExpectThisRequest(a.config.Request), a.GetServerWriter(),
|
||||
formResponseHeader(resp400),
|
||||
formResponseHeader(resp404),
|
||||
formResponseHeader(resp400))
|
||||
}
|
||||
|
||||
func NewHttpAuthenticator(ctx context.Context, config *Config) (HttpAuthenticator, error) {
|
||||
|
@ -1,9 +1,12 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -28,10 +31,15 @@ func TestReaderWriter(t *testing.T) {
|
||||
|
||||
reader := &HeaderReader{}
|
||||
buffer, err := reader.Read(cache)
|
||||
common.Must(err)
|
||||
if buffer.String() != "efg" {
|
||||
t.Error("buffer: ", buffer.String())
|
||||
if err != nil && !strings.HasPrefix(err.Error(), "malformed HTTP request") {
|
||||
t.Error("unknown error ", err)
|
||||
}
|
||||
_ = buffer
|
||||
return
|
||||
/*
|
||||
if buffer.String() != "efg" {
|
||||
t.Error("buffer: ", buffer.String())
|
||||
}*/
|
||||
}
|
||||
|
||||
func TestRequestHeader(t *testing.T) {
|
||||
@ -65,10 +73,16 @@ func TestLongRequestHeader(t *testing.T) {
|
||||
|
||||
reader := HeaderReader{}
|
||||
b, err := reader.Read(bytes.NewReader(payload))
|
||||
common.Must(err)
|
||||
if b.String() != "abcd" {
|
||||
t.Error("expect content abcd, but actually ", b.String())
|
||||
|
||||
if err != nil && !(strings.HasPrefix(err.Error(), "invalid") || strings.HasPrefix(err.Error(), "malformed")) {
|
||||
t.Error("unknown error ", err)
|
||||
}
|
||||
_ = b
|
||||
/*
|
||||
common.Must(err)
|
||||
if b.String() != "abcd" {
|
||||
t.Error("expect content abcd, but actually ", b.String())
|
||||
}*/
|
||||
}
|
||||
|
||||
func TestConnection(t *testing.T) {
|
||||
@ -143,3 +157,162 @@ func TestConnection(t *testing.T) {
|
||||
t.Error("response: ", string(actualResponse[:totalBytes]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionInvPath(t *testing.T) {
|
||||
auth, err := NewHttpAuthenticator(context.Background(), &Config{
|
||||
Request: &RequestConfig{
|
||||
Method: &Method{Value: "Post"},
|
||||
Uri: []string{"/testpath"},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Host",
|
||||
Value: []string{"www.v2ray.com", "www.google.com"},
|
||||
},
|
||||
{
|
||||
Name: "User-Agent",
|
||||
Value: []string{"Test-Agent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Response: &ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "404",
|
||||
Reason: "Not Found",
|
||||
},
|
||||
},
|
||||
})
|
||||
common.Must(err)
|
||||
|
||||
authR, err := NewHttpAuthenticator(context.Background(), &Config{
|
||||
Request: &RequestConfig{
|
||||
Method: &Method{Value: "Post"},
|
||||
Uri: []string{"/testpathErr"},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Host",
|
||||
Value: []string{"www.v2ray.com", "www.google.com"},
|
||||
},
|
||||
{
|
||||
Name: "User-Agent",
|
||||
Value: []string{"Test-Agent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Response: &ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "404",
|
||||
Reason: "Not Found",
|
||||
},
|
||||
},
|
||||
})
|
||||
common.Must(err)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
common.Must(err)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
common.Must(err)
|
||||
authConn := auth.Server(conn)
|
||||
b := make([]byte, 256)
|
||||
for {
|
||||
n, err := authConn.Read(b)
|
||||
if err != nil {
|
||||
authConn.Close()
|
||||
break
|
||||
}
|
||||
_, err = authConn.Write(b[:n])
|
||||
common.Must(err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, listener.Addr().(*net.TCPAddr))
|
||||
common.Must(err)
|
||||
|
||||
authConn := authR.Client(conn)
|
||||
defer authConn.Close()
|
||||
|
||||
authConn.Write([]byte("Test payload"))
|
||||
authConn.Write([]byte("Test payload 2"))
|
||||
|
||||
expectedResponse := "Test payloadTest payload 2"
|
||||
actualResponse := make([]byte, 256)
|
||||
deadline := time.Now().Add(time.Second * 5)
|
||||
totalBytes := 0
|
||||
for {
|
||||
n, err := authConn.Read(actualResponse[totalBytes:])
|
||||
if err != io.EOF {
|
||||
t.Error("Unexpected Error", err)
|
||||
}
|
||||
totalBytes += n
|
||||
if totalBytes >= len(expectedResponse) || time.Now().After(deadline) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestConnectionInvReq(t *testing.T) {
|
||||
auth, err := NewHttpAuthenticator(context.Background(), &Config{
|
||||
Request: &RequestConfig{
|
||||
Method: &Method{Value: "Post"},
|
||||
Uri: []string{"/testpath"},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Host",
|
||||
Value: []string{"www.v2ray.com", "www.google.com"},
|
||||
},
|
||||
{
|
||||
Name: "User-Agent",
|
||||
Value: []string{"Test-Agent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Response: &ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "404",
|
||||
Reason: "Not Found",
|
||||
},
|
||||
},
|
||||
})
|
||||
common.Must(err)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
common.Must(err)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
common.Must(err)
|
||||
authConn := auth.Server(conn)
|
||||
b := make([]byte, 256)
|
||||
for {
|
||||
n, err := authConn.Read(b)
|
||||
if err != nil {
|
||||
authConn.Close()
|
||||
break
|
||||
}
|
||||
_, err = authConn.Write(b[:n])
|
||||
common.Must(err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, listener.Addr().(*net.TCPAddr))
|
||||
common.Must(err)
|
||||
|
||||
conn.Write([]byte("ABCDEFGHIJKMLN\r\n\r\n"))
|
||||
l, _, err := bufio.NewReader(conn).ReadLine()
|
||||
common.Must(err)
|
||||
if !strings.HasPrefix(string(l), "HTTP/1.1 400 Bad Request") {
|
||||
t.Error("Resp to non http conn", string(l))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
11
transport/internet/headers/http/linkedreadRequest.go
Normal file
11
transport/internet/headers/http/linkedreadRequest.go
Normal file
@ -0,0 +1,11 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net/http"
|
||||
|
||||
_ "unsafe" // required to use //go:linkname
|
||||
)
|
||||
|
||||
//go:linkname readRequest net/http.readRequest
|
||||
func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *http.Request, err error)
|
49
transport/internet/headers/http/resp.go
Normal file
49
transport/internet/headers/http/resp.go
Normal file
@ -0,0 +1,49 @@
|
||||
package http
|
||||
|
||||
var resp400 = &ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "400",
|
||||
Reason: "Bad Request",
|
||||
},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Connection",
|
||||
Value: []string{"close"},
|
||||
},
|
||||
{
|
||||
Name: "Cache-Control",
|
||||
Value: []string{"private"},
|
||||
},
|
||||
{
|
||||
Name: "Content-Length",
|
||||
Value: []string{"0"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var resp404 = &ResponseConfig{
|
||||
Version: &Version{
|
||||
Value: "1.1",
|
||||
},
|
||||
Status: &Status{
|
||||
Code: "404",
|
||||
Reason: "Not Found",
|
||||
},
|
||||
Header: []*Header{
|
||||
{
|
||||
Name: "Connection",
|
||||
Value: []string{"close"},
|
||||
},
|
||||
{
|
||||
Name: "Cache-Control",
|
||||
Value: []string{"private"},
|
||||
},
|
||||
{
|
||||
Name: "Content-Length",
|
||||
Value: []string{"0"},
|
||||
},
|
||||
},
|
||||
}
|
Loading…
Reference in New Issue
Block a user