diff --git a/app/proxyman/mux/reader.go b/app/proxyman/mux/reader.go index d91a31ff1..6757ca2f4 100644 --- a/app/proxyman/mux/reader.go +++ b/app/proxyman/mux/reader.go @@ -4,6 +4,7 @@ import ( "io" "v2ray.com/core/common/buf" + "v2ray.com/core/common/crypto" "v2ray.com/core/common/serial" ) @@ -64,35 +65,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { return buf.NewMultiBufferValue(b), nil } -// StreamReader reads Mux frame as a stream. -type StreamReader struct { - reader *buf.BufferedReader - leftOver int32 -} - // NewStreamReader creates a new StreamReader. -func NewStreamReader(reader *buf.BufferedReader) *StreamReader { - return &StreamReader{ - reader: reader, - leftOver: -1, - } -} - -// ReadMultiBuffer implmenets buf.Reader. -func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) { - if r.leftOver == 0 { - return nil, io.EOF - } - - if r.leftOver == -1 { - size, err := serial.ReadUint16(r.reader) - if err != nil { - return nil, err - } - r.leftOver = int32(size) - } - - mb, err := r.reader.ReadAtMost(r.leftOver) - r.leftOver -= mb.Len() - return mb, err +func NewStreamReader(reader *buf.BufferedReader) buf.Reader { + return crypto.NewChunkStreamReaderWithChunkCount(crypto.PlainChunkSizeParser{}, reader, 1) } diff --git a/common/crypto/chunk.go b/common/crypto/chunk.go index b8855fd0c..322548025 100755 --- a/common/crypto/chunk.go +++ b/common/crypto/chunk.go @@ -68,12 +68,19 @@ type ChunkStreamReader struct { buffer []byte leftOverSize int32 + maxNumChunk uint32 + numChunk uint32 } func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader { + return NewChunkStreamReaderWithChunkCount(sizeDecoder, reader, 0) +} + +func NewChunkStreamReaderWithChunkCount(sizeDecoder ChunkSizeDecoder, reader io.Reader, maxNumChunk uint32) *ChunkStreamReader { r := &ChunkStreamReader{ sizeDecoder: sizeDecoder, buffer: make([]byte, sizeDecoder.SizeBytes()), + maxNumChunk: maxNumChunk, } if breader, ok := reader.(*buf.BufferedReader); ok { r.reader = breader @@ -94,6 +101,10 @@ func (r *ChunkStreamReader) readSize() (uint16, error) { func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) { size := r.leftOverSize if size == 0 { + r.numChunk++ + if r.maxNumChunk > 0 && r.numChunk > r.maxNumChunk { + return nil, io.EOF + } nextSize, err := r.readSize() if err != nil { return nil, err