diff --git a/common/buf/writer.go b/common/buf/writer.go index 16149d617..262879fd8 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -81,7 +81,7 @@ func (w *bytesToBufferWriter) Write(payload []byte) (int, error) { return len(payload), nil } -func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) { +func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) (int, error) { return mb.Len(), w.writer.Write(mb) } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 0092e7737..aa68fe422 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -4,15 +4,7 @@ import ( "crypto/cipher" "io" - "golang.org/x/crypto/sha3" - - "v2ray.com/core/common" "v2ray.com/core/common/buf" - "v2ray.com/core/common/serial" -) - -var ( - errInsufficientBuffer = newError("insufficient buffer") ) type BytesGenerator interface { @@ -68,198 +60,111 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) { return v.AEAD.Seal(dst, iv, plainText, additionalData), nil } -type Uint16Generator interface { - Next() uint16 -} - -type StaticUint16Generator uint16 - -func (g StaticUint16Generator) Next() uint16 { - return uint16(g) -} - -type ShakeUint16Generator struct { - shake sha3.ShakeHash - buffer [2]byte -} - -func NewShakeUint16Generator(nonce []byte) *ShakeUint16Generator { - shake := sha3.NewShake128() - shake.Write(nonce) - return &ShakeUint16Generator{ - shake: shake, - } -} - -func (g *ShakeUint16Generator) Next() uint16 { - g.shake.Read(g.buffer[:]) - return serial.BytesToUint16(g.buffer[:]) -} - type AuthenticationReader struct { - auth Authenticator - buffer *buf.Buffer - reader io.Reader - sizeMask Uint16Generator - - chunk []byte + auth Authenticator + buffer *buf.Buffer + reader io.Reader + sizeParser ChunkSizeDecoder } const ( readerBufferSize = 32 * 1024 ) -func NewAuthenticationReader(auth Authenticator, reader io.Reader, sizeMask Uint16Generator) *AuthenticationReader { +func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader) *AuthenticationReader { return &AuthenticationReader{ - auth: auth, - buffer: buf.NewLocal(readerBufferSize), - reader: reader, - sizeMask: sizeMask, + auth: auth, + buffer: buf.NewLocal(readerBufferSize), + reader: reader, + sizeParser: sizeParser, } } -func (v *AuthenticationReader) nextChunk(mask uint16) error { - if v.buffer.Len() < 2 { - return errInsufficientBuffer +func (r *AuthenticationReader) readChunk() error { + if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, r.sizeParser.SizeBytes())); err != nil { + return err } - size := int(serial.BytesToUint16(v.buffer.BytesTo(2)) ^ mask) - if size > v.buffer.Len()-2 { - return errInsufficientBuffer - } - if size > readerBufferSize-2 { - return newError("size too large: ", size) - } - if size == v.auth.Overhead() { - return io.EOF - } - if size < v.auth.Overhead() { - return newError("invalid packet size: ", size) - } - cipherChunk := v.buffer.BytesRange(2, size+2) - plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk) + size, err := r.sizeParser.Decode(r.buffer.Bytes()) if err != nil { return err } - v.chunk = plainChunk - v.buffer.SliceFrom(size + 2) + if size > readerBufferSize { + return newError("size too large ", size).AtWarning() + } + + if int(size) == r.auth.Overhead() { + return io.EOF + } + + if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, int(size))); err != nil { + return err + } + + b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.Bytes()) + if err != nil { + return err + } + r.buffer.Slice(0, len(b)) return nil } -func (v *AuthenticationReader) copyChunk(b []byte) int { - if len(v.chunk) == 0 { - return 0 - } - nBytes := copy(b, v.chunk) - if nBytes == len(v.chunk) { - v.chunk = nil - } else { - v.chunk = v.chunk[nBytes:] - } - return nBytes -} - -func (v *AuthenticationReader) ensureChunk() error { - atHead := false - if v.buffer.IsEmpty() { - v.buffer.Clear() - atHead = true - } - - mask := v.sizeMask.Next() - for { - err := v.nextChunk(mask) - if err != errInsufficientBuffer { - return err +func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { + if r.buffer.IsEmpty() { + if err := r.readChunk(); err != nil { + return nil, err } - - leftover := v.buffer.Bytes() - if !atHead && len(leftover) > 0 { - common.Must(v.buffer.Reset(func(b []byte) (int, error) { - return copy(b, leftover), nil - })) - } - - if err := v.buffer.AppendSupplier(buf.ReadFrom(v.reader)); err != nil { - return err - } - } -} - -func (v *AuthenticationReader) Read(b []byte) (int, error) { - if len(v.chunk) > 0 { - nBytes := v.copyChunk(b) - return nBytes, nil - } - - err := v.ensureChunk() - if err != nil { - return 0, err - } - - return v.copyChunk(b), nil -} - -func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) { - err := r.ensureChunk() - if err != nil { - return nil, err } mb := buf.NewMultiBuffer() - for len(r.chunk) > 0 { + for !r.buffer.IsEmpty() { b := buf.New() - nBytes, _ := b.Write(r.chunk) + b.AppendSupplier(buf.ReadFrom(r.buffer)) mb.Append(b) - r.chunk = r.chunk[nBytes:] } - r.chunk = nil return mb, nil } type AuthenticationWriter struct { - auth Authenticator - buffer []byte - writer io.Writer - sizeMask Uint16Generator + auth Authenticator + buffer []byte + writer io.Writer + sizeParser ChunkSizeEncoder } -func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint16Generator) *AuthenticationWriter { +func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer) *AuthenticationWriter { return &AuthenticationWriter{ - auth: auth, - buffer: make([]byte, 32*1024), - writer: writer, - sizeMask: sizeMask, + auth: auth, + buffer: make([]byte, 32*1024), + writer: writer, + sizeParser: sizeParser, } } -// Write implements io.Writer. -func (w *AuthenticationWriter) Write(b []byte) (int, error) { - cipherChunk, err := w.auth.Seal(w.buffer[2:2], b) +func (w *AuthenticationWriter) writeInternal(b []byte) error { + sizeBytes := w.sizeParser.SizeBytes() + cipherChunk, err := w.auth.Seal(w.buffer[sizeBytes:sizeBytes], b) if err != nil { - return 0, err + return err } - size := uint16(len(cipherChunk)) ^ w.sizeMask.Next() - serial.Uint16ToBytes(size, w.buffer[:0]) - _, err = w.writer.Write(w.buffer[:2+len(cipherChunk)]) - return len(b), err + w.sizeParser.Encode(uint16(len(cipherChunk)), w.buffer[:0]) + _, err = w.writer.Write(w.buffer[:sizeBytes+len(cipherChunk)]) + return err } -func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) { +func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { defer mb.Release() const StartIndex = 17 * 1024 - var totalBytes int for { payloadLen, _ := mb.Read(w.buffer[StartIndex:]) - nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen]) - totalBytes += nBytes + err := w.writeInternal(w.buffer[StartIndex : StartIndex+payloadLen]) if err != nil { - return totalBytes, err + return err } if mb.IsEmpty() { break } } - return totalBytes, nil + return nil } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index f766a52c7..9992b4016 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -6,7 +6,6 @@ import ( "crypto/rand" "io" "testing" - "time" "v2ray.com/core/common/buf" . "v2ray.com/core/common/crypto" @@ -24,8 +23,11 @@ func TestAuthenticationReaderWriter(t *testing.T) { aead, err := cipher.NewGCM(block) assert.Error(err).IsNil() - payload := make([]byte, 8*1024) - rand.Read(payload) + rawPayload := make([]byte, 8192) + rand.Read(rawPayload) + + payload := buf.NewLocal(8192) + payload.Append(rawPayload) cache := buf.NewLocal(16 * 1024) iv := make([]byte, 12) @@ -37,13 +39,11 @@ func TestAuthenticationReaderWriter(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, cache, NewShakeUint16Generator([]byte{'a'})) + }, PlainChunkSizeParser{}, cache) - nBytes, err := writer.Write(payload) - assert.Error(err).IsNil() - assert.Int(nBytes).Equals(len(payload)) - assert.Int(cache.Len()).GreaterThan(0) - _, err = writer.Write([]byte{}) + assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() + assert.Int(cache.Len()).Equals(8210) + assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() assert.Error(err).IsNil() reader := NewAuthenticationReader(&AEADAuthenticator{ @@ -52,90 +52,16 @@ func TestAuthenticationReaderWriter(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, cache, NewShakeUint16Generator([]byte{'a'})) + }, PlainChunkSizeParser{}, cache) - actualPayload := make([]byte, 16*1024) - nBytes, err = reader.Read(actualPayload) + mb, err := reader.Read() assert.Error(err).IsNil() - assert.Int(nBytes).Equals(len(payload)) - assert.Bytes(actualPayload[:nBytes]).Equals(payload) + assert.Int(mb.Len()).Equals(len(rawPayload)) - _, err = reader.Read(actualPayload) - assert.Error(err).Equals(io.EOF) -} - -func TestAuthenticationReaderWriterPartial(t *testing.T) { - assert := assert.On(t) - - key := make([]byte, 16) - rand.Read(key) - block, err := aes.NewCipher(key) - assert.Error(err).IsNil() - - aead, err := cipher.NewGCM(block) - assert.Error(err).IsNil() - - payload := make([]byte, 8*1024) - rand.Read(payload) - - iv := make([]byte, 12) - rand.Read(iv) - - cache := buf.NewLocal(16 * 1024) - writer := NewAuthenticationWriter(&AEADAuthenticator{ - AEAD: aead, - NonceGenerator: &StaticBytesGenerator{ - Content: iv, - }, - AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, cache, NewShakeUint16Generator([]byte{'a', 'b'})) - - writer.Write([]byte{'a', 'b', 'c', 'd'}) - - nBytes, err := writer.Write(payload) - assert.Error(err).IsNil() - assert.Int(nBytes).Equals(len(payload)) - assert.Int(cache.Len()).GreaterThan(0) - _, err = writer.Write([]byte{}) - assert.Error(err).IsNil() - - pr, pw := io.Pipe() - go func() { - pw.Write(cache.BytesTo(1024)) - time.Sleep(time.Second * 2) - pw.Write(cache.BytesRange(1024, 2048)) - time.Sleep(time.Second * 2) - pw.Write(cache.BytesRange(2048, 3072)) - time.Sleep(time.Second * 2) - pw.Write(cache.BytesFrom(3072)) - time.Sleep(time.Second * 2) - pw.Close() - }() - - reader := NewAuthenticationReader(&AEADAuthenticator{ - AEAD: aead, - NonceGenerator: &StaticBytesGenerator{ - Content: iv, - }, - AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, pr, NewShakeUint16Generator([]byte{'a', 'b'})) - - actualPayload := make([]byte, 7*1024) - nBytes, err = reader.Read(actualPayload) - assert.Error(err).IsNil() - assert.Int(nBytes).Equals(4) - assert.Bytes(actualPayload[:nBytes]).Equals([]byte{'a', 'b', 'c', 'd'}) - - nBytes, err = reader.Read(actualPayload) - assert.Error(err).IsNil() - assert.Int(nBytes).Equals(len(actualPayload)) - assert.Bytes(actualPayload[:nBytes]).Equals(payload[:nBytes]) - - nBytes, err = reader.Read(actualPayload) - assert.Error(err).IsNil() - assert.Int(nBytes).Equals(len(payload) - len(actualPayload)) - assert.Bytes(actualPayload[:nBytes]).Equals(payload[7*1024:]) - - _, err = reader.Read(actualPayload) + mbContent := make([]byte, 8192) + mb.Read(mbContent) + assert.Bytes(mbContent).Equals(rawPayload) + + _, err = reader.Read() assert.Error(err).Equals(io.EOF) } diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go new file mode 100644 index 000000000..fcac49238 --- /dev/null +++ b/common/crypto/chunk.go @@ -0,0 +1,150 @@ +package crypto + +import ( + "io" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/serial" +) + +type ChunkSizeDecoder interface { + SizeBytes() int + Decode([]byte) (uint16, error) +} + +type ChunkSizeEncoder interface { + SizeBytes() int + Encode(uint16, []byte) []byte +} + +type PlainChunkSizeParser struct{} + +func (PlainChunkSizeParser) SizeBytes() int { + return 2 +} + +func (PlainChunkSizeParser) Encode(size uint16, b []byte) []byte { + return serial.Uint16ToBytes(size, b) +} + +func (PlainChunkSizeParser) Decode(b []byte) (uint16, error) { + return serial.BytesToUint16(b), nil +} + +type ChunkStreamReader struct { + sizeDecoder ChunkSizeDecoder + reader buf.Reader + + buffer []byte + leftOver buf.MultiBuffer + leftOverSize uint16 +} + +func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader { + return &ChunkStreamReader{ + sizeDecoder: sizeDecoder, + reader: buf.NewReader(reader), + buffer: make([]byte, sizeDecoder.SizeBytes()), + } +} + +func (r *ChunkStreamReader) readAtLeast(size int) (buf.MultiBuffer, error) { + mb := r.leftOver + for mb.Len() < size { + extra, err := r.reader.Read() + if err != nil { + mb.Release() + return nil, err + } + mb.AppendMulti(extra) + } + + return mb, nil +} + +func (r *ChunkStreamReader) readSize() (uint16, error) { + if r.sizeDecoder.SizeBytes() > r.leftOver.Len() { + mb, err := r.readAtLeast(r.sizeDecoder.SizeBytes() - r.leftOver.Len()) + if err != nil { + return 0, err + } + r.leftOver.AppendMulti(mb) + } + r.leftOver.Read(r.buffer) + return r.sizeDecoder.Decode(r.buffer) +} + +func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) { + size := int(r.leftOverSize) + if size == 0 { + nextSize, err := r.readSize() + if err != nil { + return nil, err + } + if nextSize == 0 { + return nil, io.EOF + } + size = int(nextSize) + } + + leftOver := r.leftOver + if leftOver.IsEmpty() { + mb, err := r.readAtLeast(1) + if err != nil { + return nil, err + } + leftOver = mb + } + + if size >= leftOver.Len() { + r.leftOverSize = uint16(size - leftOver.Len()) + r.leftOver = nil + return leftOver, nil + } + + mb := leftOver.SliceBySize(size) + if mb.Len() != size { + b := buf.New() + b.AppendSupplier(buf.ReadFullFrom(&leftOver, size-mb.Len())) + mb.Append(b) + } + + r.leftOver = leftOver + r.leftOverSize = 0 + return mb, nil +} + +type ChunkStreamWriter struct { + sizeEncoder ChunkSizeEncoder + writer buf.Writer +} + +func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *ChunkStreamWriter { + return &ChunkStreamWriter{ + sizeEncoder: sizeEncoder, + writer: buf.NewWriter(writer), + } +} + +func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error { + mb2Write := buf.NewMultiBuffer() + const sliceSize = 8192 + + for { + slice := mb.SliceBySize(sliceSize) + + b := buf.New() + b.AppendSupplier(func(buffer []byte) (int, error) { + w.sizeEncoder.Encode(uint16(slice.Len()), buffer[:0]) + return w.sizeEncoder.SizeBytes(), nil + }) + mb2Write.Append(b) + mb2Write.AppendMulti(slice) + + if mb.IsEmpty() { + break + } + } + + return w.writer.Write(mb2Write) +} diff --git a/common/crypto/chunk_test.go b/common/crypto/chunk_test.go new file mode 100644 index 000000000..e9f399b0f --- /dev/null +++ b/common/crypto/chunk_test.go @@ -0,0 +1,44 @@ +package crypto_test + +import ( + "io" + "testing" + + "v2ray.com/core/common/buf" + . "v2ray.com/core/common/crypto" + "v2ray.com/core/testing/assert" +) + +func TestChunkStreamIO(t *testing.T) { + assert := assert.On(t) + + cache := buf.NewLocal(8192) + + writer := NewChunkStreamWriter(PlainChunkSizeParser{}, cache) + reader := NewChunkStreamReader(PlainChunkSizeParser{}, cache) + + b := buf.New() + b.AppendBytes('a', 'b', 'c', 'd') + assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil() + + b = buf.New() + b.AppendBytes('e', 'f', 'g') + assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil() + + assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() + + assert.Int(cache.Len()).Equals(13) + + mb, err := reader.Read() + assert.Error(err).IsNil() + assert.Int(mb.Len()).Equals(4) + assert.Bytes(mb[0].Bytes()).Equals([]byte("abcd")) + + mb, err = reader.Read() + assert.Error(err).IsNil() + assert.Int(mb.Len()).Equals(3) + assert.Bytes(mb[0].Bytes()).Equals([]byte("efg")) + + _, err = reader.Read() + assert.Error(err).Equals(io.EOF) +} diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go index c021dfb58..b09b2713e 100644 --- a/proxy/vmess/encoding/auth.go +++ b/proxy/vmess/encoding/auth.go @@ -4,6 +4,8 @@ import ( "crypto/md5" "hash/fnv" + "golang.org/x/crypto/sha3" + "v2ray.com/core/common/serial" ) @@ -14,26 +16,6 @@ func Authenticate(b []byte) uint32 { return fnv1hash.Sum32() } -type NoOpAuthenticator struct{} - -func (NoOpAuthenticator) NonceSize() int { - return 0 -} - -func (NoOpAuthenticator) Overhead() int { - return 0 -} - -// Seal implements AEAD.Seal(). -func (NoOpAuthenticator) Seal(dst, nonce, plaintext, additionalData []byte) []byte { - return append(dst[:0], plaintext...) -} - -// Open implements AEAD.Open(). -func (NoOpAuthenticator) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { - return append(dst[:0], ciphertext...), nil -} - // FnvAuthenticator is an AEAD based on Fnv hash. type FnvAuthenticator struct { } @@ -71,3 +53,36 @@ func GenerateChacha20Poly1305Key(b []byte) []byte { copy(key[16:], t[:]) return key } + +type ShakeSizeParser struct { + shake sha3.ShakeHash + buffer [2]byte +} + +func NewShakeSizeParser(nonce []byte) *ShakeSizeParser { + shake := sha3.NewShake128() + shake.Write(nonce) + return &ShakeSizeParser{ + shake: shake, + } +} + +func (s *ShakeSizeParser) SizeBytes() int { + return 2 +} + +func (s *ShakeSizeParser) next() uint16 { + s.shake.Read(s.buffer[:]) + return serial.BytesToUint16(s.buffer[:]) +} + +func (s *ShakeSizeParser) Decode(b []byte) (uint16, error) { + mask := s.next() + size := serial.BytesToUint16(b) + return mask ^ size, nil +} + +func (s *ShakeSizeParser) Encode(size uint16, b []byte) []byte { + mask := s.next() + return serial.Uint16ToBytes(mask^size, b[:0]) +} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index d57713ada..cb35568a0 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -117,23 +117,19 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ } func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { - var authWriter io.Writer - var sizeMask crypto.Uint16Generator = crypto.StaticUint16Generator(0) + var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { - sizeMask = getSizeMask(v.requestBodyIV) + sizeParser = NewShakeSizeParser(v.requestBodyIV) } if request.Security.Is(protocol.SecurityType_NONE) { if request.Option.Has(protocol.RequestOptionChunkStream) { - auth := &crypto.AEADAuthenticator{ - AEAD: NoOpAuthenticator{}, - NonceGenerator: crypto.NoOpBytesGenerator{}, - AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, - } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) - } else { - authWriter = writer + return crypto.NewChunkStreamWriter(sizeParser, writer) } - } else if request.Security.Is(protocol.SecurityType_LEGACY) { + + return buf.NewWriter(writer) + } + + if request.Security.Is(protocol.SecurityType_LEGACY) { aesStream := crypto.NewAesEncryptionStream(v.requestBodyKey, v.requestBodyIV) cryptionWriter := crypto.NewCryptionWriter(aesStream, writer) if request.Option.Has(protocol.RequestOptionChunkStream) { @@ -142,11 +138,13 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, cryptionWriter, sizeMask) - } else { - authWriter = cryptionWriter + return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter) } - } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + + return buf.NewWriter(cryptionWriter) + } + + if request.Security.Is(protocol.SecurityType_AES128_GCM) { block, _ := aes.NewCipher(v.requestBodyKey) aead, _ := cipher.NewGCM(block) @@ -158,8 +156,10 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) - } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + return crypto.NewAuthenticationWriter(auth, sizeParser, writer) + } + + if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.requestBodyKey)) auth := &crypto.AEADAuthenticator{ @@ -170,11 +170,10 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer) } - return buf.NewWriter(authWriter) - + panic("Unknown security type.") } func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) { @@ -216,34 +215,32 @@ func (v *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon } func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, reader io.Reader) buf.Reader { - var authReader io.Reader - var sizeMask crypto.Uint16Generator = crypto.StaticUint16Generator(0) + var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { - sizeMask = getSizeMask(v.responseBodyIV) + sizeParser = NewShakeSizeParser(v.responseBodyIV) } if request.Security.Is(protocol.SecurityType_NONE) { if request.Option.Has(protocol.RequestOptionChunkStream) { - auth := &crypto.AEADAuthenticator{ - AEAD: NoOpAuthenticator{}, - NonceGenerator: crypto.NoOpBytesGenerator{}, - AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, - } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) - } else { - authReader = reader + return crypto.NewChunkStreamReader(sizeParser, reader) } - } else if request.Security.Is(protocol.SecurityType_LEGACY) { + + return buf.NewReader(reader) + } + + if request.Security.Is(protocol.SecurityType_LEGACY) { if request.Option.Has(protocol.RequestOptionChunkStream) { auth := &crypto.AEADAuthenticator{ AEAD: new(FnvAuthenticator), NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, v.responseReader, sizeMask) - } else { - authReader = v.responseReader + return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader) } - } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + + return buf.NewReader(v.responseReader) + } + + if request.Security.Is(protocol.SecurityType_AES128_GCM) { block, _ := aes.NewCipher(v.responseBodyKey) aead, _ := cipher.NewGCM(block) @@ -255,8 +252,10 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) - } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + return crypto.NewAuthenticationReader(auth, sizeParser, reader) + } + + if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.responseBodyKey)) auth := &crypto.AEADAuthenticator{ @@ -267,10 +266,10 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) + return crypto.NewAuthenticationReader(auth, sizeParser, reader) } - return buf.NewReader(authReader) + panic("Unknown security type.") } type ChunkNonceGenerator struct { diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 658ae4e30..12a76ed2d 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -93,10 +93,6 @@ func (h *SessionHistory) run() { } } -func getSizeMask(b []byte) crypto.Uint16Generator { - return crypto.NewShakeUint16Generator(b) -} - type ServerSession struct { userValidator protocol.UserValidator sessionHistory *SessionHistory @@ -239,23 +235,19 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request } func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) buf.Reader { - var authReader io.Reader - var sizeMask crypto.Uint16Generator = crypto.StaticUint16Generator(0) + var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { - sizeMask = getSizeMask(v.requestBodyIV) + sizeParser = NewShakeSizeParser(v.requestBodyIV) } if request.Security.Is(protocol.SecurityType_NONE) { if request.Option.Has(protocol.RequestOptionChunkStream) { - auth := &crypto.AEADAuthenticator{ - AEAD: NoOpAuthenticator{}, - NonceGenerator: crypto.NoOpBytesGenerator{}, - AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, - } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) - } else { - authReader = reader + return crypto.NewChunkStreamReader(sizeParser, reader) } - } else if request.Security.Is(protocol.SecurityType_LEGACY) { + + return buf.NewReader(reader) + } + + if request.Security.Is(protocol.SecurityType_LEGACY) { aesStream := crypto.NewAesDecryptionStream(v.requestBodyKey, v.requestBodyIV) cryptionReader := crypto.NewCryptionReader(aesStream, reader) if request.Option.Has(protocol.RequestOptionChunkStream) { @@ -264,11 +256,13 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, cryptionReader, sizeMask) - } else { - authReader = cryptionReader + return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader) } - } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + + return buf.NewReader(cryptionReader) + } + + if request.Security.Is(protocol.SecurityType_AES128_GCM) { block, _ := aes.NewCipher(v.requestBodyKey) aead, _ := cipher.NewGCM(block) @@ -280,8 +274,10 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) - } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + return crypto.NewAuthenticationReader(auth, sizeParser, reader) + } + + if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.requestBodyKey)) auth := &crypto.AEADAuthenticator{ @@ -292,10 +288,10 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authReader = crypto.NewAuthenticationReader(auth, reader, sizeMask) + return crypto.NewAuthenticationReader(auth, sizeParser, reader) } - return buf.NewReader(authReader) + panic("Unknown security type.") } func (v *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { @@ -316,34 +312,32 @@ func (v *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr } func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { - var authWriter io.Writer - var sizeMask crypto.Uint16Generator = crypto.StaticUint16Generator(0) + var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { - sizeMask = getSizeMask(v.responseBodyIV) + sizeParser = NewShakeSizeParser(v.responseBodyIV) } if request.Security.Is(protocol.SecurityType_NONE) { if request.Option.Has(protocol.RequestOptionChunkStream) { - auth := &crypto.AEADAuthenticator{ - AEAD: NoOpAuthenticator{}, - NonceGenerator: crypto.NoOpBytesGenerator{}, - AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, - } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) - } else { - authWriter = writer + return crypto.NewChunkStreamWriter(sizeParser, writer) } - } else if request.Security.Is(protocol.SecurityType_LEGACY) { + + return buf.NewWriter(writer) + } + + if request.Security.Is(protocol.SecurityType_LEGACY) { if request.Option.Has(protocol.RequestOptionChunkStream) { auth := &crypto.AEADAuthenticator{ AEAD: new(FnvAuthenticator), NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, v.responseWriter, sizeMask) - } else { - authWriter = v.responseWriter + return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter) } - } else if request.Security.Is(protocol.SecurityType_AES128_GCM) { + + return buf.NewWriter(v.responseWriter) + } + + if request.Security.Is(protocol.SecurityType_AES128_GCM) { block, _ := aes.NewCipher(v.responseBodyKey) aead, _ := cipher.NewGCM(block) @@ -355,8 +349,10 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) - } else if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { + return crypto.NewAuthenticationWriter(auth, sizeParser, writer) + } + + if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.responseBodyKey)) auth := &crypto.AEADAuthenticator{ @@ -367,8 +363,8 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - authWriter = crypto.NewAuthenticationWriter(auth, writer, sizeMask) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer) } - return buf.NewWriter(authWriter) + panic("Unknown security type.") }