diff --git a/app/dispatcher/stats_test.go b/app/dispatcher/stats_test.go index 7091e7e64..cdac8d9aa 100644 --- a/app/dispatcher/stats_test.go +++ b/app/dispatcher/stats_test.go @@ -32,12 +32,10 @@ func TestStatsWriter(t *testing.T) { Writer: buf.Discard, } - var mb buf.MultiBuffer - common.Must2(mb.Write([]byte("abcd"))) + mb := buf.MergeBytes(nil, []byte("abcd")) common.Must(writer.WriteMultiBuffer(mb)) - mb = buf.ReleaseMulti(mb) - common.Must2(mb.Write([]byte("efg"))) + mb = buf.MergeBytes(nil, []byte("efg")) common.Must(writer.WriteMultiBuffer(mb)) if c.Value() != 7 { diff --git a/app/reverse/portal.go b/app/reverse/portal.go index fb5f397ec..cc844624b 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -251,8 +251,7 @@ func (w *PortalWorker) heartbeat() error { b, err := proto.Marshal(msg) common.Must(err) - var mb buf.MultiBuffer - common.Must2(mb.Write(b)) + mb := buf.MergeBytes(nil, b) return w.writer.WriteMultiBuffer(mb) } diff --git a/common/buf/multi_buffer.go b/common/buf/multi_buffer.go index ce1a21cc7..c675b4090 100644 --- a/common/buf/multi_buffer.go +++ b/common/buf/multi_buffer.go @@ -8,21 +8,9 @@ import ( "v2ray.com/core/common/serial" ) -// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF. -func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) { - mb := make(MultiBuffer, 0, 128) - - if _, err := mb.ReadFrom(reader); err != nil { - ReleaseMulti(mb) - return nil, err - } - - return mb, nil -} - // ReadAllToBytes reads all content from the reader into a byte array, until EOF. func ReadAllToBytes(reader io.Reader) ([]byte, error) { - mb, err := ReadAllToMultiBuffer(reader) + mb, err := ReadFrom(reader) if err != nil { return nil, err } @@ -30,7 +18,8 @@ func ReadAllToBytes(reader io.Reader) ([]byte, error) { return nil, nil } b := make([]byte, mb.Len()) - common.Must2(mb.Read(b)) + mb, _, err = SplitBytes(mb, b) + common.Must(err) ReleaseMulti(mb) return b, nil } @@ -47,6 +36,23 @@ func MergeMulti(dest MultiBuffer, src MultiBuffer) (MultiBuffer, MultiBuffer) { return dest, src[:0] } +func MergeBytes(dest MultiBuffer, src []byte) MultiBuffer { + n := len(dest) + if n > 0 && !(dest)[n-1].IsFull() { + nBytes, _ := (dest)[n-1].Write(src) + src = src[nBytes:] + } + + for len(src) > 0 { + b := New() + nBytes, _ := b.Write(src) + src = src[nBytes:] + dest = append(dest, b) + } + + return dest +} + // ReleaseMulti release all content of the MultiBuffer, and returns an empty MultiBuffer. func ReleaseMulti(mb MultiBuffer) MultiBuffer { for i := range mb { @@ -69,93 +75,42 @@ func (mb MultiBuffer) Copy(b []byte) int { return total } -// ReadFrom implements io.ReaderFrom. -func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) { - totalBytes := int64(0) - +// ReadFrom reads all content from reader until EOF. +func ReadFrom(reader io.Reader) (MultiBuffer, error) { + mb := make(MultiBuffer, 0, 16) for { b := New() _, err := b.ReadFullFrom(reader, Size) if b.IsEmpty() { b.Release() } else { - *mb = append(*mb, b) + mb = append(mb, b) } - totalBytes += int64(b.Len()) if err != nil { if errors.Cause(err) == io.EOF || errors.Cause(err) == io.ErrUnexpectedEOF { - return totalBytes, nil + return mb, nil } - return totalBytes, err + return mb, err } } } -// Read implements io.Reader. -func (mb *MultiBuffer) Read(b []byte) (int, error) { - if mb.IsEmpty() { - return 0, io.EOF - } - endIndex := len(*mb) +func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int, error) { totalBytes := 0 - for i, bb := range *mb { + + for len(mb) > 0 { + bb := mb[0] nBytes, _ := bb.Read(b) totalBytes += nBytes b = b[nBytes:] - if bb.IsEmpty() { - bb.Release() - (*mb)[i] = nil - } else { - endIndex = i + if !bb.IsEmpty() { break } - } - *mb = (*mb)[endIndex:] - return totalBytes, nil -} - -// WriteTo implements io.WriterTo. -func (mb *MultiBuffer) WriteTo(writer io.Writer) (int64, error) { - defer func() { - *mb = ReleaseMulti(*mb) - }() - - totalBytes := int64(0) - for _, b := range *mb { - nBytes, err := writer.Write(b.Bytes()) - totalBytes += int64(nBytes) - if err != nil { - return totalBytes, err - } + bb.Release() + mb = mb[1:] } - return totalBytes, nil -} - -// Write implements io.Writer. -func (mb *MultiBuffer) Write(b []byte) (int, error) { - totalBytes := len(b) - - n := len(*mb) - if n > 0 && !(*mb)[n-1].IsFull() { - nBytes, _ := (*mb)[n-1].Write(b) - b = b[nBytes:] - } - - for len(b) > 0 { - bb := New() - nBytes, _ := bb.Write(b) - b = b[nBytes:] - *mb = append(*mb, bb) - } - - return totalBytes, nil -} - -// WriteMultiBuffer implements Writer. -func (mb *MultiBuffer) WriteMultiBuffer(b MultiBuffer) error { - *mb, _ = MergeMulti(*mb, b) - return nil + return mb, totalBytes, nil } // Len returns the total number of bytes in the MultiBuffer. @@ -223,3 +178,39 @@ func (mb *MultiBuffer) SplitFirst() *Buffer { *mb = (*mb)[1:] return b } + +type MultiBufferContainer struct { + MultiBuffer +} + +func (c *MultiBufferContainer) Read(b []byte) (int, error) { + if c.MultiBuffer.IsEmpty() { + return 0, io.EOF + } + + mb, nBytes, err := SplitBytes(c.MultiBuffer, b) + c.MultiBuffer = mb + return nBytes, err +} + +func (c *MultiBufferContainer) ReadMultiBuffer() (MultiBuffer, error) { + mb := c.MultiBuffer + c.MultiBuffer = nil + return mb, nil +} + +func (c *MultiBufferContainer) Write(b []byte) (int, error) { + c.MultiBuffer = MergeBytes(c.MultiBuffer, b) + return len(b), nil +} + +func (c *MultiBufferContainer) WriteMultiBuffer(b MultiBuffer) error { + mb, _ := MergeMulti(c.MultiBuffer, b) + c.MultiBuffer = mb + return nil +} + +func (c *MultiBufferContainer) Close() error { + c.MultiBuffer = ReleaseMulti(c.MultiBuffer) + return nil +} diff --git a/common/buf/multi_buffer_test.go b/common/buf/multi_buffer_test.go index 2d329328c..9bec3a4ca 100644 --- a/common/buf/multi_buffer_test.go +++ b/common/buf/multi_buffer_test.go @@ -21,7 +21,7 @@ func TestMultiBufferRead(t *testing.T) { mb := MultiBuffer{b1, b2} bs := make([]byte, 32) - nBytes, err := mb.Read(bs) + _, nBytes, err := SplitBytes(mb, bs) assert(err, IsNil) assert(nBytes, Equals, 4) assert(bs[:nBytes], Equals, []byte("abcd")) @@ -43,16 +43,8 @@ func TestMultiBufferSliceBySizeLarge(t *testing.T) { lb := make([]byte, 8*1024) common.Must2(io.ReadFull(rand.Reader, lb)) - var mb MultiBuffer - common.Must2(mb.Write(lb)) + mb := MergeBytes(nil, lb) mb2 := mb.SliceBySize(1024) assert(mb2.Len(), Equals, int32(1024)) } - -func TestInterface(t *testing.T) { - assert := With(t) - - assert((*MultiBuffer)(nil), Implements, (*io.WriterTo)(nil)) - assert((*MultiBuffer)(nil), Implements, (*io.ReaderFrom)(nil)) -} diff --git a/common/buf/reader.go b/common/buf/reader.go index 90baa48cb..39503beb9 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -46,8 +46,9 @@ func (r *BufferedReader) ReadByte() (byte, error) { // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader. func (r *BufferedReader) Read(b []byte) (int, error) { if !r.Buffer.IsEmpty() { - nBytes, err := r.Buffer.Read(b) + buffer, nBytes, err := SplitBytes(r.Buffer, b) common.Must(err) + r.Buffer = buffer if r.Buffer.IsEmpty() { r.Buffer = nil } @@ -59,12 +60,12 @@ func (r *BufferedReader) Read(b []byte) (int, error) { return 0, err } - nBytes, err := mb.Read(b) + mb, nBytes, err := SplitBytes(mb, b) common.Must(err) if !mb.IsEmpty() { r.Buffer = mb } - return nBytes, err + return nBytes, nil } // ReadMultiBuffer implements Reader. diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index dddc87d64..c9f079b7b 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -69,8 +69,7 @@ func TestReadByte(t *testing.T) { t.Error("unexpected byte: ", b, " want a") } - var mb MultiBuffer - nBytes, err := reader.WriteTo(&mb) + nBytes, err := reader.WriteTo(DiscardBytes) common.Must(err) if nBytes != 3 { t.Error("unexpect bytes written: ", nBytes) diff --git a/common/buf/readv_test.go b/common/buf/readv_test.go index 62cb539fe..7feef9989 100644 --- a/common/buf/readv_test.go +++ b/common/buf/readv_test.go @@ -33,8 +33,7 @@ func TestReadvReader(t *testing.T) { go func() { writer := NewWriter(conn) - var mb MultiBuffer - common.Must2(mb.Write(data)) + mb := MergeBytes(nil, data) if err := writer.WriteMultiBuffer(mb); err != nil { t.Fatal("failed to write data: ", err) @@ -58,7 +57,8 @@ func TestReadvReader(t *testing.T) { } rdata := make([]byte, size) - common.Must2(rmb.Read(rdata)) + _, _, err = SplitBytes(rmb, rdata) + common.Must(err) if err := compare.BytesEqualWithDetail(data, rdata); err != nil { t.Fatal(err) diff --git a/common/buf/writer.go b/common/buf/writer.go index b6d5a15b3..5e513fef7 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -134,15 +134,16 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error { return w.writer.WriteMultiBuffer(b) } - defer ReleaseMulti(b) + reader := MultiBufferContainer{ + MultiBuffer: b, + } + defer reader.Close() - for !b.IsEmpty() { + for !reader.MultiBuffer.IsEmpty() { if w.buffer == nil { w.buffer = New() } - if _, err := w.buffer.ReadFrom(&b); err != nil { - return err - } + common.Must2(w.buffer.ReadFrom(&reader)) if w.buffer.IsFull() { if err := w.flushInternal(); err != nil { return err diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 35147200a..d400c6ebe 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -194,7 +194,7 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro return err } - common.Must2(mb.Write(rb)) + *mb = buf.MergeBytes(*mb, rb) return nil } @@ -279,11 +279,17 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { payloadSize := buf.Size - int32(w.auth.Overhead()) - w.sizeParser.SizeBytes() - maxPadding mb2Write := make(buf.MultiBuffer, 0, len(mb)+10) + temp := buf.New() + defer temp.Release() + + rawBytes := temp.Extend(payloadSize) + for { - b := buf.New() - common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize)))) - eb, err := w.seal(b.Bytes()) - b.Release() + nb, nBytes, err := buf.SplitBytes(mb, rawBytes) + common.Must(err) + mb = nb + + eb, err := w.seal(rawBytes[:nBytes]) if err != nil { buf.ReleaseMulti(mb2Write) diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index b53c99f3e..b79a15be6 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -30,8 +30,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { rawPayload := make([]byte, payloadSize) rand.Read(rawPayload) - var payload buf.MultiBuffer - payload.Write(rawPayload) + payload := buf.MergeBytes(nil, rawPayload) assert(payload.Len(), Equals, int32(payloadSize)) cache := bytes.NewBuffer(nil) @@ -66,7 +65,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { assert(mb.Len(), Equals, int32(payloadSize)) mbContent := make([]byte, payloadSize) - mb.Read(mbContent) + buf.SplitBytes(mb, mbContent) assert(mbContent, Equals, rawPayload) _, err = reader.ReadMultiBuffer() diff --git a/common/net/connection.go b/common/net/connection.go index c1341bd86..2678cbda2 100644 --- a/common/net/connection.go +++ b/common/net/connection.go @@ -100,7 +100,7 @@ func (c *connection) Write(b []byte) (int, error) { l := len(b) mb := make(buf.MultiBuffer, 0, l/buf.Size+1) - common.Must2(mb.Write(b)) + mb = buf.MergeBytes(mb, b) return l, c.writer.WriteMultiBuffer(mb) } diff --git a/common/platform/ctlcmd/ctlcmd.go b/common/platform/ctlcmd/ctlcmd.go index 459329537..ed770b555 100644 --- a/common/platform/ctlcmd/ctlcmd.go +++ b/common/platform/ctlcmd/ctlcmd.go @@ -17,8 +17,8 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) { return nil, newError("v2ctl doesn't exist").Base(err) } - errBuffer := buf.MultiBuffer{} - outBuffer := buf.MultiBuffer{} + var errBuffer buf.MultiBufferContainer + var outBuffer buf.MultiBufferContainer cmd := exec.Command(v2ctl, args...) cmd.Stderr = &errBuffer @@ -35,12 +35,10 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) { if err := cmd.Wait(); err != nil { msg := "failed to execute v2ctl" if errBuffer.Len() > 0 { - msg += ": " + errBuffer.String() + msg += ": " + errBuffer.MultiBuffer.String() } - buf.ReleaseMulti(errBuffer) - buf.ReleaseMulti(outBuffer) return nil, newError(msg).Base(err) } - return outBuffer, nil + return outBuffer.MultiBuffer, nil } diff --git a/main/confloader/external/external.go b/main/confloader/external/external.go index ef5611fe9..75dcdea37 100644 --- a/main/confloader/external/external.go +++ b/main/confloader/external/external.go @@ -12,16 +12,6 @@ import ( //go:generate errorgen -type ClosableMultiBuffer struct { - buf.MultiBuffer -} - -func (c *ClosableMultiBuffer) Close() error { - buf.ReleaseMulti(c.MultiBuffer) - c.MultiBuffer = nil - return nil -} - func loadConfigFile(configFile string) (io.ReadCloser, error) { if configFile == "stdin:" { return os.Stdin, nil @@ -32,7 +22,9 @@ func loadConfigFile(configFile string) (io.ReadCloser, error) { if err != nil { return nil, err } - return &ClosableMultiBuffer{content}, nil + return &buf.MultiBufferContainer{ + MultiBuffer: content, + }, nil } fixedFile := os.ExpandEnv(configFile) @@ -42,12 +34,13 @@ func loadConfigFile(configFile string) (io.ReadCloser, error) { } defer file.Close() - content, err := buf.ReadAllToMultiBuffer(file) + content, err := buf.ReadFrom(file) if err != nil { return nil, newError("failed to load config file: ", fixedFile).Base(err).AtWarning() } - return &ClosableMultiBuffer{content}, nil - + return &buf.MultiBufferContainer{ + MultiBuffer: content, + }, nil } func init() { diff --git a/main/json/config_json.go b/main/json/config_json.go index 81ac55f40..45c6374d9 100644 --- a/main/json/config_json.go +++ b/main/json/config_json.go @@ -7,6 +7,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" + "v2ray.com/core/common/buf" "v2ray.com/core/common/platform/ctlcmd" ) @@ -19,7 +20,9 @@ func init() { if err != nil { return nil, newError("failed to execute v2ctl to convert config file.").Base(err).AtWarning() } - return core.LoadConfig("protobuf", "", &jsonContent) + return core.LoadConfig("protobuf", "", &buf.MultiBufferContainer{ + MultiBuffer: jsonContent, + }) }, })) } diff --git a/proxy/http/server.go b/proxy/http/server.go index 67262e1e0..5c298e649 100755 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -185,8 +185,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade } if reader.Buffered() > 0 { - var payload buf.MultiBuffer - _, err := payload.ReadFrom(&io.LimitedReader{R: reader, N: int64(reader.Buffered())}) + payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered()))) if err != nil { return err } diff --git a/proxy/shadowsocks/ota.go b/proxy/shadowsocks/ota.go index ab0d3b597..10507cc8d 100644 --- a/proxy/shadowsocks/ota.go +++ b/proxy/shadowsocks/ota.go @@ -92,8 +92,7 @@ func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) { return nil, newError("invalid auth") } - var mb buf.MultiBuffer - common.Must2(mb.Write(payload)) + mb := buf.MergeBytes(nil, payload) return mb, nil } @@ -117,7 +116,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { defer buf.ReleaseMulti(mb) for { - payloadLen, _ := mb.Read(w.buffer[2+AuthSize:]) + mb, payloadLen, _ := buf.SplitBytes(mb, w.buffer[2+AuthSize:]) binary.BigEndian.PutUint16(w.buffer, uint16(payloadLen)) w.auth.Authenticate(w.buffer[2+AuthSize:2+AuthSize+payloadLen], w.buffer[2:]) if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil { diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 31494006c..df6f90086 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -1,6 +1,7 @@ package kcp import ( + "bytes" "io" "net" "runtime" @@ -8,7 +9,6 @@ import ( "sync/atomic" "time" - "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/signal" "v2ray.com/core/common/signal/semaphore" @@ -364,12 +364,8 @@ func (c *Connection) waitForDataOutput() error { // Write implements io.Writer. func (c *Connection) Write(b []byte) (int, error) { - // This involves multiple copies of the buffer. But we don't expect this method to be used often. - // Only wrapped connections such as TLS and WebSocket will call into this. - // TODO: improve efficiency. - var mb buf.MultiBuffer - common.Must2(mb.Write(b)) - if err := c.WriteMultiBuffer(mb); err != nil { + reader := bytes.NewReader(b) + if err := c.writeMultiBufferInternal(reader); err != nil { return 0, err } return len(b), nil @@ -377,8 +373,15 @@ func (c *Connection) Write(b []byte) (int, error) { // WriteMultiBuffer implements buf.Writer. func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { - defer buf.ReleaseMulti(mb) + reader := &buf.MultiBufferContainer{ + MultiBuffer: mb, + } + defer reader.Close() + return c.writeMultiBufferInternal(reader) +} + +func (c *Connection) writeMultiBufferInternal(reader io.Reader) error { updatePending := false defer func() { if updatePending { @@ -386,19 +389,28 @@ func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { } }() + var b *buf.Buffer + defer b.Release() + for { for { if c == nil || c.State() != StateActive { return io.ErrClosedPipe } - if !c.sendingWorker.Push(&mb) { + if b == nil { + b = buf.New() + _, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss))) + if err != nil { + return nil + } + } + + if !c.sendingWorker.Push(b) { break } updatePending = true - if mb.IsEmpty() { - return nil - } + b = nil } if updatePending { diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index 86c389e14..0e6ea908a 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -209,7 +209,7 @@ func (w *ReceivingWorker) Read(b []byte) int { if mb.IsEmpty() { return 0 } - nBytes, err := mb.Read(b) + mb, nBytes, err := buf.SplitBytes(mb, b) common.Must(err) if !mb.IsEmpty() { w.leftOver = mb diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index c734ea364..23d9ee09b 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -2,10 +2,8 @@ package kcp import ( "container/list" - "io" "sync" - "v2ray.com/core/common" "v2ray.com/core/common/buf" ) @@ -262,7 +260,7 @@ func (w *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint } } -func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool { +func (w *SendingWorker) Push(b *buf.Buffer) bool { w.Lock() defer w.Unlock() @@ -274,8 +272,6 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool { return false } - b := buf.New() - common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss)))) w.window.Push(w.nextNumber, b) w.nextNumber++ return true