diff --git a/common/buf/writer.go b/common/buf/writer.go index 66d24b0e2..a6d4bca00 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -3,6 +3,7 @@ package buf import ( "io" "net" + "sync" "v2ray.com/core/common" "v2ray.com/core/common/errors" @@ -56,6 +57,7 @@ func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) { // BufferedWriter is a Writer with internal buffer. type BufferedWriter struct { + sync.Mutex writer Writer buffer *Buffer buffered bool @@ -77,6 +79,9 @@ func (w *BufferedWriter) WriteByte(c byte) error { // Write implements io.Writer. func (w *BufferedWriter) Write(b []byte) (int, error) { + w.Lock() + defer w.Unlock() + if !w.buffered { if writer, ok := w.writer.(io.Writer); ok { return writer.Write(b) @@ -95,7 +100,7 @@ func (w *BufferedWriter) Write(b []byte) (int, error) { return totalBytes, err } if !w.buffered || w.buffer.IsFull() { - if err := w.Flush(); err != nil { + if err := w.flushInternal(); err != nil { return totalBytes, err } } @@ -107,6 +112,9 @@ func (w *BufferedWriter) Write(b []byte) (int, error) { // WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer. func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { + w.Lock() + defer w.Unlock() + if !w.buffered { return w.writer.WriteMultiBuffer(b) } @@ -121,7 +129,7 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { return err } if w.buffer.IsFull() { - if err := w.Flush(); err != nil { + if err := w.flushInternal(); err != nil { return err } } @@ -132,6 +140,13 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { // Flush flushes buffered content into underlying writer. func (w *BufferedWriter) Flush() error { + w.Lock() + defer w.Unlock() + + return w.flushInternal() +} + +func (w *BufferedWriter) flushInternal() error { if w.buffer.IsEmpty() { return nil } @@ -150,9 +165,12 @@ func (w *BufferedWriter) Flush() error { // SetBuffered sets whether the internal buffer is used. If set to false, Flush() will be called to clear the buffer. func (w *BufferedWriter) SetBuffered(f bool) error { + w.Lock() + defer w.Unlock() + w.buffered = f if !f { - return w.Flush() + return w.flushInternal() } return nil }