diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 5f911d902..0e38e58f1 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -98,6 +98,9 @@ func (v *AuthenticationReader) NextChunk() error { if size == v.auth.Overhead() { return io.EOF } + if size < v.auth.Overhead() { + return errors.New("AuthenticationReader: invalid packet size.") + } cipherChunk := v.buffer.BytesRange(2, size+2) plainChunk, err := v.auth.Open(cipherChunk, cipherChunk) if err != nil { @@ -176,7 +179,7 @@ func (v *AuthenticationWriter) Write(b []byte) (int, error) { if err != nil { return 0, err } - serial.Uint16ToBytes(uint16(len(cipherChunk)), b[:0]) + serial.Uint16ToBytes(uint16(len(cipherChunk)), v.buffer[:0]) _, err = v.writer.Write(v.buffer[:2+len(cipherChunk)]) return len(b), err } diff --git a/common/io/writer.go b/common/io/writer.go index cf255f6a6..f6bb86c65 100644 --- a/common/io/writer.go +++ b/common/io/writer.go @@ -29,11 +29,14 @@ func NewAdaptiveWriter(writer io.Writer) *AdaptiveWriter { // Write implements Writer.Write(). Write() takes ownership of the given buffer. func (v *AdaptiveWriter) Write(buffer *alloc.Buffer) error { defer buffer.Release() - for !buffer.IsEmpty() { + for { nBytes, err := v.writer.Write(buffer.Bytes()) if err != nil { return err } + if nBytes == buffer.Len() { + break + } buffer.SliceFrom(nBytes) } return nil diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 1c38c5df1..c72e3c52e 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -31,13 +31,20 @@ func (v *RequestOption) Clear(option RequestOption) { *v = (*v & (^option)) } +type Security byte + +func (v Security) Is(t SecurityType) bool { + return v == Security(t) +} + type RequestHeader struct { - Version byte - User *User - Command RequestCommand - Option RequestOption - Address v2net.Address - Port v2net.Port + Version byte + User *User + Command RequestCommand + Option RequestOption + Security Security + Address v2net.Address + Port v2net.Port } func (v *RequestHeader) Destination() v2net.Destination { diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go index c2dbd3009..22e22b475 100644 --- a/proxy/vmess/encoding/auth.go +++ b/proxy/vmess/encoding/auth.go @@ -2,6 +2,7 @@ package encoding import ( "hash/fnv" + "v2ray.com/core/common/crypto" "v2ray.com/core/common/serial" ) diff --git a/proxy/vmess/encoding/auth_test.go b/proxy/vmess/encoding/auth_test.go new file mode 100644 index 000000000..9818c6733 --- /dev/null +++ b/proxy/vmess/encoding/auth_test.go @@ -0,0 +1,24 @@ +package encoding_test + +import ( + "crypto/rand" + "testing" + + . "v2ray.com/core/proxy/vmess/encoding" + "v2ray.com/core/testing/assert" +) + +func TestFnvAuth(t *testing.T) { + assert := assert.On(t) + fnvAuth := new(FnvAuthenticator) + + expectedText := make([]byte, 256) + rand.Read(expectedText) + + buffer := make([]byte, 512) + b := fnvAuth.Seal(buffer, nil, expectedText, nil) + b, err := fnvAuth.Open(buffer, nil, b, nil) + assert.Error(err).IsNil() + assert.Int(len(b)).Equals(256) + assert.Bytes(b).Equals(expectedText) +} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 160a1c153..2b0d34d76 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -1,15 +1,23 @@ package encoding import ( + "crypto/aes" + "crypto/cipher" "crypto/md5" "crypto/rand" "fmt" "hash/fnv" "io" + + "golang.org/x/crypto/chacha20poly1305" + "v2ray.com/core/common/crypto" + "v2ray.com/core/common/dice" + v2io "v2ray.com/core/common/io" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" + "v2ray.com/core/common/serial" "v2ray.com/core/proxy/vmess" ) @@ -64,7 +72,10 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ buffer = append(buffer, Version) buffer = append(buffer, v.requestBodyIV...) buffer = append(buffer, v.requestBodyKey...) - buffer = append(buffer, v.responseHeader, byte(header.Option), byte(0), byte(0), byte(header.Command)) + buffer = append(buffer, v.responseHeader, byte(header.Option)) + padingLen := dice.Roll(16) + security := byte(padingLen<<4) | byte(header.Security) + buffer = append(buffer, security, byte(0), byte(header.Command)) buffer = header.Port.Bytes(buffer) switch header.Address.Family() { @@ -79,6 +90,10 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ buffer = append(buffer, header.Address.Domain()...) } + pading := make([]byte, padingLen) + rand.Read(pading) + buffer = append(buffer, pading...) + fnv1a := fnv.New32a() fnv1a.Write(buffer) @@ -94,9 +109,61 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ return } -func (v *ClientSession) EncodeRequestBody(writer io.Writer) io.Writer { - aesStream := crypto.NewAesEncryptionStream(v.requestBodyKey, v.requestBodyIV) - return crypto.NewCryptionWriter(aesStream, writer) +func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) v2io.Writer { + var authWriter io.Writer + if request.Security.Is(protocol.SecurityType_NONE) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } else { + authWriter = writer + } + } else if request.Security.Is(protocol.SecurityType_LEGACY) { + aesStream := crypto.NewAesEncryptionStream(v.requestBodyKey, v.requestBodyIV) + cryptionWriter := crypto.NewCryptionWriter(aesStream, writer) + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, cryptionWriter) + } else { + authWriter = cryptionWriter + } + } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + block, _ := aes.NewCipher(v.responseBodyKey) + aead, _ := cipher.NewGCM(block) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + aead, _ := chacha20poly1305.New(v.responseBodyKey) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } + + return v2io.NewAdaptiveWriter(authWriter) + } func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { @@ -107,7 +174,7 @@ func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon _, err := io.ReadFull(v.responseReader, buffer[:4]) if err != nil { - log.Info("Raw: Failed to read response header: ", err) + log.Info("VMess|Client: Failed to read response header: ", err) return nil, err } @@ -124,7 +191,7 @@ func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon dataLen := int(buffer[3]) _, err := io.ReadFull(v.responseReader, buffer[:dataLen]) if err != nil { - log.Info("Raw: Failed to read response command: ", err) + log.Info("VMess|Client: Failed to read response command: ", err) return nil, err } data := buffer[:dataLen] @@ -137,6 +204,69 @@ func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon return header, nil } -func (v *ClientSession) DecodeResponseBody(reader io.Reader) io.Reader { - return v.responseReader +func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, reader io.Reader) v2io.Reader { + aggressive := (request.Command == protocol.RequestCommandTCP) + var authReader io.Reader + if request.Security.Is(protocol.SecurityType_NONE) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } else { + authReader = reader + } + } else if request.Security.Is(protocol.SecurityType_LEGACY) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, v.responseReader, aggressive) + } else { + authReader = v.responseReader + } + } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + block, _ := aes.NewCipher(v.responseBodyKey) + aead, _ := cipher.NewGCM(block) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + aead, _ := chacha20poly1305.New(v.responseBodyKey) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } + + return v2io.NewAdaptiveReader(authReader) +} + +type ChunkNonceGenerator struct { + Nonce []byte + Size int + count uint16 +} + +func (v *ChunkNonceGenerator) Next() []byte { + serial.Uint16ToBytes(v.count, v.Nonce[:2]) + v.count++ + return v.Nonce[:v.Size] } diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index cd6ef93ab..48be15bbc 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -27,12 +27,13 @@ func TestRequestSerialization(t *testing.T) { user.Account = loader.NewTypedSettings(account) expectedRequest := &protocol.RequestHeader{ - Version: 1, - User: user, - Command: protocol.RequestCommandTCP, - Option: protocol.RequestOption(0), - Address: v2net.DomainAddress("www.v2ray.com"), - Port: v2net.Port(443), + Version: 1, + User: user, + Command: protocol.RequestCommandTCP, + Option: protocol.RequestOption(0), + Address: v2net.DomainAddress("www.v2ray.com"), + Port: v2net.Port(443), + Security: protocol.Security(protocol.SecurityType_AES128_GCM), } buffer := alloc.NewBuffer() @@ -51,4 +52,5 @@ func TestRequestSerialization(t *testing.T) { assert.Byte(byte(expectedRequest.Option)).Equals(byte(actualRequest.Option)) assert.Address(expectedRequest.Address).Equals(actualRequest.Address) assert.Port(expectedRequest.Port).Equals(actualRequest.Port) + assert.Byte(byte(expectedRequest.Security)).Equals(byte(actualRequest.Security)) } diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 5197ae257..2c066acae 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -1,11 +1,17 @@ package encoding import ( + "crypto/aes" + "crypto/cipher" "crypto/md5" "hash/fnv" "io" + + "golang.org/x/crypto/chacha20poly1305" + "v2ray.com/core/common/crypto" "v2ray.com/core/common/errors" + v2io "v2ray.com/core/common/io" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -84,7 +90,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request v.requestBodyIV = append([]byte(nil), buffer[1:17]...) // 16 bytes v.requestBodyKey = append([]byte(nil), buffer[17:33]...) // 16 bytes v.responseHeader = buffer[33] // 1 byte - request.Option = protocol.RequestOption(buffer[34]) // 1 byte + 2 bytes reserved + request.Option = protocol.RequestOption(buffer[34]) // 1 byte + padingLen := int(buffer[35] >> 4) + request.Security = protocol.Security(buffer[35] & 0x0F) + // 1 bytes reserved request.Command = protocol.RequestCommand(buffer[37]) request.Port = v2net.PortFromBytes(buffer[38:40]) @@ -121,6 +130,14 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request request.Address = v2net.DomainAddress(string(buffer[42 : 42+domainLength])) } + if padingLen > 0 { + _, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+padingLen]) + if err != nil { + return nil, errors.New("VMess|Server: Failed to read padding.") + } + bufferLen += padingLen + } + _, err = io.ReadFull(decryptor, buffer[bufferLen:bufferLen+4]) if err != nil { return nil, errors.Base(err).Message("VMess|Server: Failed to read checksum.") @@ -138,9 +155,61 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request return request, nil } -func (v *ServerSession) DecodeRequestBody(reader io.Reader) io.Reader { - aesStream := crypto.NewAesDecryptionStream(v.requestBodyKey, v.requestBodyIV) - return crypto.NewCryptionReader(aesStream, reader) +func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) v2io.Reader { + aggressive := (request.Command == protocol.RequestCommandTCP) + var authReader io.Reader + if request.Security.Is(protocol.SecurityType_NONE) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } else { + authReader = reader + } + } else if request.Security.Is(protocol.SecurityType_LEGACY) { + aesStream := crypto.NewAesDecryptionStream(v.requestBodyKey, v.requestBodyIV) + cryptionReader := crypto.NewCryptionReader(aesStream, reader) + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, cryptionReader, aggressive) + } else { + authReader = cryptionReader + } + } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + block, _ := aes.NewCipher(v.responseBodyKey) + aead, _ := cipher.NewGCM(block) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + aead, _ := chacha20poly1305.New(v.responseBodyKey) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authReader = crypto.NewAuthenticationReader(auth, reader, aggressive) + } + + return v2io.NewAdaptiveReader(authReader) } func (v *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { @@ -160,6 +229,56 @@ func (v *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr } } -func (v *ServerSession) EncodeResponseBody(writer io.Writer) io.Writer { - return v.responseWriter +func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) v2io.Writer { + var authWriter io.Writer + if request.Security.Is(protocol.SecurityType_NONE) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } else { + authWriter = writer + } + } else if request.Security.Is(protocol.SecurityType_LEGACY) { + if request.Option.Has(protocol.RequestOptionChunkStream) { + auth := &crypto.AEADAuthenticator{ + AEAD: new(FnvAuthenticator), + NonceGenerator: crypto.NoOpBytesGenerator{}, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, v.responseWriter) + } else { + authWriter = v.responseWriter + } + } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + block, _ := aes.NewCipher(v.responseBodyKey) + aead, _ := cipher.NewGCM(block) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + aead, _ := chacha20poly1305.New(v.responseBodyKey) + + auth := &crypto.AEADAuthenticator{ + AEAD: aead, + NonceGenerator: &ChunkNonceGenerator{ + Nonce: append([]byte(nil), v.responseBodyIV...), + Size: aead.NonceSize(), + }, + AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, + } + authWriter = crypto.NewAuthenticationWriter(auth, writer) + } + + return v2io.NewAdaptiveWriter(authWriter) } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 968cbbd1a..4f7fdca89 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -9,7 +9,6 @@ import ( "v2ray.com/core/app/proxyman" "v2ray.com/core/common" "v2ray.com/core/common/alloc" - "v2ray.com/core/common/crypto" "v2ray.com/core/common/errors" v2io "v2ray.com/core/common/io" "v2ray.com/core/common/loader" @@ -21,7 +20,6 @@ import ( "v2ray.com/core/proxy/registry" "v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess/encoding" - vmessio "v2ray.com/core/proxy/vmess/io" "v2ray.com/core/transport/internet" ) @@ -187,24 +185,12 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { reader.SetCached(false) go func() { - bodyReader := session.DecodeRequestBody(reader) - var requestReader v2io.Reader - if request.Option.Has(protocol.RequestOptionChunkStream) { - auth := &crypto.AEADAuthenticator{ - AEAD: new(encoding.FnvAuthenticator), - NonceGenerator: crypto.NoOpBytesGenerator{}, - AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, - } - authReader := crypto.NewAuthenticationReader(auth, bodyReader, request.Command == protocol.RequestCommandTCP) - requestReader = v2io.NewAdaptiveReader(authReader) - } else { - requestReader = v2io.NewAdaptiveReader(bodyReader) - } - if err := v2io.PipeUntilEOF(requestReader, input); err != nil { + bodyReader := session.DecodeRequestBody(request, reader) + if err := v2io.PipeUntilEOF(bodyReader, input); err != nil { connection.SetReusable(false) } + bodyReader.Release() - requestReader.Release() input.Close() readFinish.Unlock() }() @@ -222,33 +208,29 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) { session.EncodeResponseHeader(response, writer) - bodyWriter := session.EncodeResponseBody(writer) - var v2writer v2io.Writer = v2io.NewAdaptiveWriter(bodyWriter) - if request.Option.Has(protocol.RequestOptionChunkStream) { - v2writer = vmessio.NewAuthChunkWriter(v2writer) - } + bodyWriter := session.EncodeResponseBody(request, writer) // Optimize for small response packet if data, err := output.Read(); err == nil { - if err := v2writer.Write(data); err != nil { + if err := bodyWriter.Write(data); err != nil { connection.SetReusable(false) } writer.SetCached(false) - if err := v2io.PipeUntilEOF(output, v2writer); err != nil { + if err := v2io.PipeUntilEOF(output, bodyWriter); err != nil { connection.SetReusable(false) } } output.Release() if request.Option.Has(protocol.RequestOptionChunkStream) { - if err := v2writer.Write(alloc.NewLocalBuffer(32)); err != nil { + if err := bodyWriter.Write(alloc.NewLocalBuffer(32)); err != nil { connection.SetReusable(false) } } writer.Flush() - v2writer.Release() + bodyWriter.Release() readFinish.Lock() } diff --git a/proxy/vmess/io/io_test.go b/proxy/vmess/io/io_test.go deleted file mode 100644 index 16dffec0c..000000000 --- a/proxy/vmess/io/io_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package io_test - -import ( - "bytes" - "crypto/rand" - "io" - "testing" - - "v2ray.com/core/common/alloc" - "v2ray.com/core/common/errors" - v2io "v2ray.com/core/common/io" - "v2ray.com/core/common/serial" - . "v2ray.com/core/proxy/vmess/io" - "v2ray.com/core/testing/assert" -) - -func TestAuthenticate(t *testing.T) { - assert := assert.On(t) - - buffer := alloc.NewBuffer() - buffer.AppendBytes(1, 2, 3, 4) - Authenticate(buffer) - assert.Bytes(buffer.Bytes()).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4}) - - b2, err := NewAuthChunkReader(buffer).Read() - assert.Error(err).IsNil() - assert.Bytes(b2.Bytes()).Equals([]byte{1, 2, 3, 4}) -} - -func TestSingleIO(t *testing.T) { - assert := assert.On(t) - - content := bytes.NewBuffer(make([]byte, 0, 1024*1024)) - - writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(content)) - b := alloc.NewBuffer() - b.AppendFunc(serial.WriteString("abcd")) - writer.Write(b) - writer.Write(alloc.NewBuffer()) - writer.Release() - - reader := NewAuthChunkReader(content) - buffer, err := reader.Read() - assert.Error(err).IsNil() - assert.String(buffer.String()).Equals("abcd") -} - -func TestLargeIO(t *testing.T) { - assert := assert.On(t) - - content := make([]byte, 1024*1024) - rand.Read(content) - - chunckContent := bytes.NewBuffer(make([]byte, 0, len(content)*2)) - writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(chunckContent)) - writeSize := 0 - for { - chunkSize := 7 * 1024 - if chunkSize+writeSize > len(content) { - chunkSize = len(content) - writeSize - } - b := alloc.NewBuffer() - b.Append(content[writeSize : writeSize+chunkSize]) - writer.Write(b) - b.Release() - - writeSize += chunkSize - if writeSize == len(content) { - break - } - } - writer.Write(alloc.NewBuffer()) - writer.Release() - - actualContent := make([]byte, 0, len(content)) - reader := NewAuthChunkReader(chunckContent) - for { - buffer, err := reader.Read() - if errors.Cause(err) == io.EOF { - break - } - assert.Error(err).IsNil() - actualContent = append(actualContent, buffer.Bytes()...) - } - - assert.Int(len(actualContent)).Equals(len(content)) - assert.Bytes(actualContent).Equals(content) -} diff --git a/proxy/vmess/io/reader.go b/proxy/vmess/io/reader.go deleted file mode 100644 index f1b5e9c4c..000000000 --- a/proxy/vmess/io/reader.go +++ /dev/null @@ -1,116 +0,0 @@ -package io - -import ( - "hash" - "hash/fnv" - "io" - "v2ray.com/core/common/alloc" - "v2ray.com/core/common/errors" - "v2ray.com/core/common/serial" -) - -// Private: Visible for testing. -type Validator struct { - actualAuth hash.Hash32 - expectedAuth uint32 -} - -func NewValidator(expectedAuth uint32) *Validator { - return &Validator{ - actualAuth: fnv.New32a(), - expectedAuth: expectedAuth, - } -} - -func (v *Validator) Consume(b []byte) { - v.actualAuth.Write(b) -} - -func (v *Validator) Validate() bool { - return v.actualAuth.Sum32() == v.expectedAuth -} - -type AuthChunkReader struct { - reader io.Reader - last *alloc.Buffer - chunkLength int - validator *Validator -} - -func NewAuthChunkReader(reader io.Reader) *AuthChunkReader { - return &AuthChunkReader{ - reader: reader, - chunkLength: -1, - } -} - -func (v *AuthChunkReader) Read() (*alloc.Buffer, error) { - var buffer *alloc.Buffer - if v.last != nil { - buffer = v.last - v.last = nil - } else { - buffer = alloc.NewBuffer() - } - - if v.chunkLength == -1 { - for buffer.Len() < 6 { - _, err := buffer.FillFrom(v.reader) - if err != nil { - buffer.Release() - return nil, io.ErrUnexpectedEOF - } - } - length := serial.BytesToUint16(buffer.BytesTo(2)) - v.chunkLength = int(length) - 4 - v.validator = NewValidator(serial.BytesToUint32(buffer.BytesRange(2, 6))) - buffer.SliceFrom(6) - if buffer.Len() < v.chunkLength && v.chunkLength <= 2048 { - _, err := buffer.FillFrom(v.reader) - if err != nil { - buffer.Release() - return nil, io.ErrUnexpectedEOF - } - } - } else if buffer.Len() < v.chunkLength { - _, err := buffer.FillFrom(v.reader) - if err != nil { - buffer.Release() - return nil, io.ErrUnexpectedEOF - } - } - - if v.chunkLength == 0 { - buffer.Release() - return nil, io.EOF - } - - if buffer.Len() < v.chunkLength { - v.validator.Consume(buffer.Bytes()) - v.chunkLength -= buffer.Len() - } else { - v.validator.Consume(buffer.BytesTo(v.chunkLength)) - if !v.validator.Validate() { - buffer.Release() - return nil, errors.New("VMess|AuthChunkReader: Invalid auth.") - } - leftLength := buffer.Len() - v.chunkLength - if leftLength > 0 { - v.last = alloc.NewBuffer() - v.last.Append(buffer.BytesFrom(v.chunkLength)) - buffer.Slice(0, v.chunkLength) - } - - v.chunkLength = -1 - v.validator = nil - } - - return buffer, nil -} - -func (v *AuthChunkReader) Release() { - v.reader = nil - v.last.Release() - v.last = nil - v.validator = nil -} diff --git a/proxy/vmess/io/writer.go b/proxy/vmess/io/writer.go deleted file mode 100644 index 5c6ce82af..000000000 --- a/proxy/vmess/io/writer.go +++ /dev/null @@ -1,37 +0,0 @@ -package io - -import ( - "hash/fnv" - - "v2ray.com/core/common/alloc" - v2io "v2ray.com/core/common/io" - "v2ray.com/core/common/serial" -) - -type AuthChunkWriter struct { - writer v2io.Writer -} - -func NewAuthChunkWriter(writer v2io.Writer) *AuthChunkWriter { - return &AuthChunkWriter{ - writer: writer, - } -} - -func (v *AuthChunkWriter) Write(buffer *alloc.Buffer) error { - Authenticate(buffer) - return v.writer.Write(buffer) -} - -func (v *AuthChunkWriter) Release() { - v.writer.Release() - v.writer = nil -} - -func Authenticate(buffer *alloc.Buffer) { - fnvHash := fnv.New32a() - fnvHash.Write(buffer.Bytes()) - buffer.PrependFunc(4, serial.WriteHash(fnvHash)) - - buffer.PrependFunc(2, serial.WriteUint16(uint16(buffer.Len()))) -} diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 923d31df8..90683185a 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -14,7 +14,6 @@ import ( "v2ray.com/core/proxy" "v2ray.com/core/proxy/registry" "v2ray.com/core/proxy/vmess/encoding" - vmessio "v2ray.com/core/proxy/vmess/io" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/ray" ) @@ -92,13 +91,11 @@ func (v *VMessOutboundHandler) handleRequest(session *encoding.ClientSession, co defer writer.Release() session.EncodeRequestHeader(request, writer) - bodyWriter := session.EncodeRequestBody(writer) - var streamWriter v2io.Writer = v2io.NewAdaptiveWriter(bodyWriter) - if request.Option.Has(protocol.RequestOptionChunkStream) { - streamWriter = vmessio.NewAuthChunkWriter(streamWriter) - } + bodyWriter := session.EncodeRequestBody(request, writer) + defer bodyWriter.Release() + if !payload.IsEmpty() { - if err := streamWriter.Write(payload); err != nil { + if err := bodyWriter.Write(payload); err != nil { log.Info("VMess|Outbound: Failed to write payload. Disabling connection reuse.", err) conn.SetReusable(false) } @@ -106,17 +103,16 @@ func (v *VMessOutboundHandler) handleRequest(session *encoding.ClientSession, co } writer.SetCached(false) - if err := v2io.PipeUntilEOF(input, streamWriter); err != nil { + if err := v2io.PipeUntilEOF(input, bodyWriter); err != nil { conn.SetReusable(false) } if request.Option.Has(protocol.RequestOptionChunkStream) { - err := streamWriter.Write(alloc.NewLocalBuffer(32)) + err := bodyWriter.Write(alloc.NewLocalBuffer(32)) if err != nil { conn.SetReusable(false) } } - streamWriter.Release() return } @@ -139,20 +135,13 @@ func (v *VMessOutboundHandler) handleResponse(session *encoding.ClientSession, c } reader.SetCached(false) - decryptReader := session.DecodeResponseBody(reader) - - var bodyReader v2io.Reader - if request.Option.Has(protocol.RequestOptionChunkStream) { - bodyReader = vmessio.NewAuthChunkReader(decryptReader) - } else { - bodyReader = v2io.NewAdaptiveReader(decryptReader) - } + bodyReader := session.DecodeResponseBody(request, reader) + defer bodyReader.Release() if err := v2io.PipeUntilEOF(bodyReader, output); err != nil { conn.SetReusable(false) } - bodyReader.Release() return }