diff --git a/common/buf/io.go b/common/buf/io.go index d82795cd2..a7abeff1d 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -57,18 +57,6 @@ func NewReader(reader io.Reader) Reader { return &BytesToBufferReader{ reader: reader, - buffer: make([]byte, 32*1024), - } -} - -func NewMergingReader(reader io.Reader) Reader { - return NewMergingReaderSize(reader, 32*1024) -} - -func NewMergingReaderSize(reader io.Reader, size uint32) Reader { - return &BytesToBufferReader{ - reader: reader, - buffer: make([]byte, size), } } diff --git a/common/buf/reader.go b/common/buf/reader.go index cade2a787..8715075a7 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -12,8 +12,30 @@ type BytesToBufferReader struct { buffer []byte } +func NewBytesToBufferReader(reader io.Reader) Reader { + return &BytesToBufferReader{ + reader: reader, + } +} + +func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) { + b := New() + if err := b.Reset(ReadFrom(r.reader)); err != nil { + b.Release() + return nil, err + } + if b.IsFull() { + r.buffer = make([]byte, 32*1024) + } + return NewMultiBufferValue(b), nil +} + // Read implements Reader.Read(). func (r *BytesToBufferReader) Read() (MultiBuffer, error) { + if r.buffer == nil { + return r.readSmall() + } + nBytes, err := r.reader.Read(r.buffer) if err != nil { return nil, err diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index 7c79d0fdc..f33ed4ced 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -14,12 +14,13 @@ import ( func TestAdaptiveReader(t *testing.T) { assert := With(t) - rawContent := make([]byte, 1024*1024) - buffer := bytes.NewBuffer(rawContent) - - reader := NewReader(buffer) + reader := NewReader(bytes.NewReader(make([]byte, 1024*1024))) b, err := reader.Read() assert(err, IsNil) + assert(b.Len(), Equals, 2*1024) + + b, err = reader.Read() + assert(err, IsNil) assert(b.Len(), Equals, 32*1024) } diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index 12f5a155a..42ade5f79 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -23,7 +23,7 @@ type conn struct { func (c *conn) ReadMultiBuffer() (buf.MultiBuffer, error) { if c.mergingReader == nil { - c.mergingReader = buf.NewMergingReaderSize(c.Conn, 16*1024) + c.mergingReader = buf.NewBytesToBufferReader(c.Conn) } return c.mergingReader.Read() } diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index 123621aa5..bcdde3cff 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -49,7 +49,7 @@ func (c *connection) Read(b []byte) (int, error) { func (c *connection) ReadMultiBuffer() (buf.MultiBuffer, error) { if c.mergingReader == nil { - c.mergingReader = buf.NewMergingReader(c) + c.mergingReader = buf.NewBytesToBufferReader(c) } return c.mergingReader.Read() }