diff --git a/common/buf/reader.go b/common/buf/reader.go index 833fcda51..7f74e76a0 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -98,27 +98,17 @@ func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) { func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) { mbWriter := NewWriter(writer) - totalBytes := int64(0) + var sc SizeCounter if r.Buffer != nil { - totalBytes += int64(r.Buffer.Len()) + sc.Size = int64(r.Buffer.Len()) if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil { return 0, err } r.Buffer = nil } - for { - mb, err := r.Reader.ReadMultiBuffer() - if mb != nil { - totalBytes += int64(mb.Len()) - if werr := mbWriter.WriteMultiBuffer(mb); werr != nil { - return totalBytes, err - } - } - if err != nil { - return totalBytes, err - } - } + err := Copy(r.Reader, mbWriter, CountSize(&sc)) + return sc.Size, err } // WriteTo implements io.WriterTo. diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index 836df4e52..9e3458423 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -2,8 +2,10 @@ package buf_test import ( "io" + "strings" "testing" + "v2ray.com/core/common" . "v2ray.com/core/common/buf" "v2ray.com/core/transport/pipe" . "v2ray.com/ext/assert" @@ -56,6 +58,25 @@ func TestBytesReaderMultiBuffer(t *testing.T) { assert(mb[1].String(), Equals, "efg") } +func TestReadByte(t *testing.T) { + sr := strings.NewReader("abcd") + reader := &BufferedReader{ + Reader: NewReader(sr), + } + b, err := reader.ReadByte() + common.Must(err) + if b != 'a' { + t.Error("unexpected byte: ", b, " want a") + } + + var mb MultiBuffer + nBytes, err := reader.WriteTo(&mb) + common.Must(err) + if nBytes != 3 { + t.Error("unexpect bytes written: ", nBytes) + } +} + func TestReaderInterface(t *testing.T) { _ = (io.Reader)(new(ReadVReader)) _ = (Reader)(new(ReadVReader))