diff --git a/common/io/buffered_writer.go b/common/io/buffered_writer.go index d63912f06..7899c4216 100644 --- a/common/io/buffered_writer.go +++ b/common/io/buffered_writer.go @@ -17,7 +17,7 @@ type BufferedWriter struct { func NewBufferedWriter(rawWriter io.Writer) *BufferedWriter { return &BufferedWriter{ writer: rawWriter, - buffer: alloc.NewBuffer().Clear(), + buffer: alloc.NewSmallBuffer().Clear(), cached: true, } } @@ -55,11 +55,22 @@ func (v *BufferedWriter) Write(b []byte) (int, error) { if !v.cached { return v.writer.Write(b) } - nBytes, _ := v.buffer.Write(b) - if v.buffer.IsFull() { - v.FlushWithoutLock() + nBytes, err := v.buffer.Write(b) + if err != nil { + return 0, err } - return nBytes, nil + if v.buffer.IsFull() { + err := v.FlushWithoutLock() + if err != nil { + return 0, err + } + if nBytes < len(b) { + if _, err := v.writer.Write(b[nBytes:]); err != nil { + return nBytes, err + } + } + } + return len(b), nil } func (v *BufferedWriter) Flush() error { diff --git a/common/io/buffered_writer_test.go b/common/io/buffered_writer_test.go index bf999f3ce..8013de3ba 100644 --- a/common/io/buffered_writer_test.go +++ b/common/io/buffered_writer_test.go @@ -1,6 +1,7 @@ package io_test import ( + "crypto/rand" "testing" "v2ray.com/core/common/alloc" @@ -27,3 +28,26 @@ func TestBufferedWriter(t *testing.T) { writer.SetCached(false) assert.Int(content.Len()).Equals(16) } + +func TestBufferedWriterLargePayload(t *testing.T) { + assert := assert.On(t) + + content := alloc.NewLocalBuffer(128 * 1024).Clear() + + writer := NewBufferedWriter(content) + assert.Bool(writer.Cached()).IsTrue() + + payload := make([]byte, 64*1024) + rand.Read(payload) + + nBytes, err := writer.Write(payload[:1024]) + assert.Int(nBytes).Equals(1024) + assert.Error(err).IsNil() + + assert.Bool(content.IsEmpty()).IsTrue() + + nBytes, err = writer.Write(payload[1024:]) + assert.Error(err).IsNil() + assert.Int(nBytes).Equals(63 * 1024) + assert.Bytes(content.Value).Equals(payload) +}