diff --git a/proxy/vmess/io/io_test.go b/proxy/vmess/io/io_test.go new file mode 100644 index 000000000..c271e2867 --- /dev/null +++ b/proxy/vmess/io/io_test.go @@ -0,0 +1,109 @@ +package io_test + +import ( + "bytes" + "crypto/rand" + "io" + "testing" + + "github.com/v2ray/v2ray-core/common/alloc" + v2io "github.com/v2ray/v2ray-core/common/io" + . "github.com/v2ray/v2ray-core/proxy/vmess/io" + v2testing "github.com/v2ray/v2ray-core/testing" + "github.com/v2ray/v2ray-core/testing/assert" +) + +func TestAuthenticate(t *testing.T) { + v2testing.Current(t) + + buffer := alloc.NewBuffer().Clear() + buffer.AppendBytes(1, 2, 3, 4) + Authenticate(buffer) + assert.Bytes(buffer.Value).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4}) + + b2, err := NewAuthChunkReader(buffer).Read() + assert.Error(err).IsNil() + assert.Bytes(b2.Value).Equals([]byte{1, 2, 3, 4}) +} + +func TestSingleIO(t *testing.T) { + v2testing.Current(t) + + content := bytes.NewBuffer(make([]byte, 0, 1024*1024)) + + writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(content)) + writer.Write(alloc.NewBuffer().Clear().AppendString("abcd")) + writer.Release() + + reader := NewAuthChunkReader(content) + buffer, err := reader.Read() + assert.Error(err).IsNil() + assert.Bytes(buffer.Value).Equals([]byte("abcd")) +} + +func TestLargeIO(t *testing.T) { + v2testing.Current(t) + + content := make([]byte, 1024*1024) + rand.Read(content) + + chunckContent := bytes.NewBuffer(make([]byte, 0, len(content)*2)) + writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(chunckContent)) + writeSize := 0 + for { + chunkSize := 7 * 1024 + if chunkSize+writeSize > len(content) { + chunkSize = len(content) - writeSize + } + writer.Write(alloc.NewBuffer().Clear().Append(content[writeSize : writeSize+chunkSize])) + writeSize += chunkSize + if writeSize == len(content) { + break + } + + chunkSize = 8 * 1024 + if chunkSize+writeSize > len(content) { + chunkSize = len(content) - writeSize + } + writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize])) + writeSize += chunkSize + if writeSize == len(content) { + break + } + + chunkSize = 63 * 1024 + if chunkSize+writeSize > len(content) { + chunkSize = len(content) - writeSize + } + writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize])) + writeSize += chunkSize + if writeSize == len(content) { + break + } + + chunkSize = 64*1024 - 16 + if chunkSize+writeSize > len(content) { + chunkSize = len(content) - writeSize + } + writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize])) + writeSize += chunkSize + if writeSize == len(content) { + break + } + } + writer.Release() + + actualContent := make([]byte, 0, len(content)) + reader := NewAuthChunkReader(chunckContent) + for { + buffer, err := reader.Read() + if err == io.EOF { + break + } + assert.Error(err).IsNil() + actualContent = append(actualContent, buffer.Value...) + } + + assert.Int(len(actualContent)).Equals(len(content)) + assert.Bytes(actualContent).Equals(content) +} diff --git a/proxy/vmess/io/reader.go b/proxy/vmess/io/reader.go index 5dab18ff1..a9debadf9 100644 --- a/proxy/vmess/io/reader.go +++ b/proxy/vmess/io/reader.go @@ -1,6 +1,7 @@ package io import ( + "hash" "hash/fnv" "io" @@ -9,49 +10,115 @@ import ( "github.com/v2ray/v2ray-core/transport" ) +// @Private +func AllocBuffer(size int) *alloc.Buffer { + if size < 8*1024-16 { + return alloc.NewBuffer() + } + return alloc.NewLargeBuffer() +} + +// @Private +type Validator struct { + actualAuth hash.Hash32 + expectedAuth uint32 +} + +func NewValidator(expectedAuth uint32) *Validator { + return &Validator{ + actualAuth: fnv.New32a(), + expectedAuth: expectedAuth, + } +} + +func (this *Validator) Consume(b []byte) { + this.actualAuth.Write(b) +} + +func (this *Validator) Validate() bool { + return this.actualAuth.Sum32() == this.expectedAuth +} + type AuthChunkReader struct { - reader io.Reader + reader io.Reader + last *alloc.Buffer + chunkLength int + validator *Validator } func NewAuthChunkReader(reader io.Reader) *AuthChunkReader { return &AuthChunkReader{ - reader: reader, + reader: reader, + chunkLength: -1, } } func (this *AuthChunkReader) Read() (*alloc.Buffer, error) { - buffer := alloc.NewBuffer() - if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil { + var buffer *alloc.Buffer + if this.last != nil { + buffer = this.last + this.last = nil + } else { + buffer = AllocBuffer(this.chunkLength).Clear() + } + + _, err := buffer.FillFrom(this.reader) + if err != nil { buffer.Release() return nil, err } - length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value() - if length <= 4 { // Length of authentication bytes. + if this.chunkLength == -1 { + for buffer.Len() < 6 { + _, err := buffer.FillFrom(this.reader) + if err != nil { + buffer.Release() + return nil, err + } + } + length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value() + this.chunkLength = int(length) - 4 + this.validator = NewValidator(serial.BytesLiteral(buffer.Value[2:6]).Uint32Value()) + buffer.SliceFrom(6) + } + + if this.chunkLength == 0 { + buffer.Release() return nil, io.EOF } - if length > 8*1024-16 { - buffer.Release() - buffer = alloc.NewLargeBuffer() - } - if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil { - buffer.Release() - return nil, err - } - buffer.Slice(0, int(length)) - fnvHash := fnv.New32a() - fnvHash.Write(buffer.Value[4:]) - expAuth := serial.BytesLiteral(fnvHash.Sum(nil)) - actualAuth := serial.BytesLiteral(buffer.Value[:4]) - if !actualAuth.Equals(expAuth) { - buffer.Release() - return nil, transport.ErrorCorruptedPacket + if buffer.Len() <= this.chunkLength { + this.validator.Consume(buffer.Value) + this.chunkLength -= buffer.Len() + if this.chunkLength == 0 { + if !this.validator.Validate() { + buffer.Release() + return nil, transport.ErrorCorruptedPacket + } + this.chunkLength = -1 + this.validator = nil + } + } else { + this.validator.Consume(buffer.Value[:this.chunkLength]) + if !this.validator.Validate() { + buffer.Release() + return nil, transport.ErrorCorruptedPacket + } + leftLength := buffer.Len() - this.chunkLength + this.last = AllocBuffer(leftLength).Clear() + this.last.Append(buffer.Value[this.chunkLength:]) + buffer.Slice(0, this.chunkLength) + + this.chunkLength = -1 + this.validator = nil } - buffer.SliceFrom(4) + return buffer, nil } func (this *AuthChunkReader) Release() { this.reader = nil + this.last.Release() + this.last = nil + this.validator = nil } diff --git a/proxy/vmess/io/writer_test.go b/proxy/vmess/io/writer_test.go deleted file mode 100644 index e5cc0af9f..000000000 --- a/proxy/vmess/io/writer_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package io_test - -import ( - "testing" - - "github.com/v2ray/v2ray-core/common/alloc" - . "github.com/v2ray/v2ray-core/proxy/vmess/io" - v2testing "github.com/v2ray/v2ray-core/testing" - "github.com/v2ray/v2ray-core/testing/assert" -) - -func TestAuthenticate(t *testing.T) { - v2testing.Current(t) - - buffer := alloc.NewBuffer().Clear() - buffer.AppendBytes(1, 2, 3, 4) - Authenticate(buffer) - assert.Bytes(buffer.Value).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4}) - - b2, err := NewAuthChunkReader(buffer).Read() - assert.Error(err).IsNil() - assert.Bytes(b2.Value).Equals([]byte{1, 2, 3, 4}) -}